diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py index 7045d8810493e..bbed80ebe8476 100644 --- a/.buildkite/generate_index.py +++ b/.buildkite/generate_index.py @@ -8,7 +8,8 @@ template = """

Links for vLLM

- {wheel}
+ {x86_wheel}
+ {arm_wheel}
""" @@ -21,7 +22,25 @@ filename = os.path.basename(args.wheel) with open("index.html", "w") as f: print(f"Generated index.html for {args.wheel}") + # sync the abi tag with .buildkite/scripts/upload-wheels.sh + if "x86_64" in filename: + x86_wheel = filename + arm_wheel = filename.replace("x86_64", "aarch64").replace( + "manylinux1", "manylinux2014" + ) + elif "aarch64" in filename: + x86_wheel = filename.replace("aarch64", "x86_64").replace( + "manylinux2014", "manylinux1" + ) + arm_wheel = filename + else: + raise ValueError(f"Unsupported wheel: {filename}") # cloudfront requires escaping the '+' character f.write( - template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B")) + template.format( + x86_wheel=x86_wheel, + x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"), + arm_wheel=arm_wheel, + arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"), + ) ) diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml deleted file mode 100644 index 56ec933c9cc0e..0000000000000 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml +++ /dev/null @@ -1,12 +0,0 @@ -# For vllm script, with -t option (tensor parallel size). 
-# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1 -model_name: "HandH1998/QQQ-Llama-3-8b-g128" -tasks: -- name: "gsm8k" - metrics: - - name: "exact_match,strict-match" - value: 0.419 - - name: "exact_match,flexible-extract" - value: 0.416 -limit: 1000 -num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt index 27a1a9a82bd35..37eeac85c933b 100644 --- a/.buildkite/lm-eval-harness/configs/models-large.txt +++ b/.buildkite/lm-eval-harness/configs/models-large.txt @@ -3,4 +3,3 @@ Meta-Llama-3-70B-Instruct.yaml Mixtral-8x7B-Instruct-v0.1.yaml Qwen2-57B-A14-Instruct.yaml DeepSeek-V2-Lite-Chat.yaml -Meta-Llama-3-8B-QQQ.yaml diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh index a67fc89d54e60..897f84d1e360d 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh @@ -2,7 +2,7 @@ # We can use this script to compute baseline accuracy on GSM for transformers. # # Make sure you have lm-eval-harness installed: -# pip install lm-eval==0.4.4 +# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] usage() { echo`` diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh index b98d42aa7b822..792f355c47a51 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -3,7 +3,7 @@ # We use this for fp8, which HF does not support. 
# # Make sure you have lm-eval-harness installed: -# pip install lm-eval==0.4.4 +# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] usage() { echo`` diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md index b39f9899a8f28..e6f5c8b60f459 100644 --- a/.buildkite/nightly-benchmarks/README.md +++ b/.buildkite/nightly-benchmarks/README.md @@ -141,7 +141,7 @@ When run, benchmark script generates results under `benchmark/results` folder, a `compare-json-results.py` compares two `benchmark_results.json` files and provides performance ratio e.g. for Output Tput, Median TTFT and Median TPOT. If only one benchmark_results.json is passed, `compare-json-results.py` compares different TP and PP configurations in the benchmark_results.json instead. -Here is an example using the script to compare result_a and result_b with Model, Dataset name, input/output lenght, max concurrency and qps. +Here is an example using the script to compare result_a and result_b with Model, Dataset name, input/output length, max concurrency and qps. 
`python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json` | | Model | Dataset Name | Input Len | Output Len | # of max concurrency | qps | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio | diff --git a/.buildkite/nightly-benchmarks/nightly-descriptions.md b/.buildkite/nightly-benchmarks/nightly-descriptions.md index 8afde017d383e..37e2980eea974 100644 --- a/.buildkite/nightly-benchmarks/nightly-descriptions.md +++ b/.buildkite/nightly-benchmarks/nightly-descriptions.md @@ -17,7 +17,7 @@ Latest reproduction guilde: [github issue link](https://github.com/vllm-project/ - SGLang: `lmsysorg/sglang:v0.3.2-cu121` - LMDeploy: `openmmlab/lmdeploy:v0.6.1-cu12` - TensorRT-LLM: `nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3` - - *NOTE: we uses r24.07 as the current implementation only works for this version. We are going to bump this up.* + - *NOTE: we use r24.07 as the current implementation only works for this version. We are going to bump this up.* - Check [nightly-pipeline.yaml](nightly-pipeline.yaml) for the concrete docker images, specs and commands we use for the benchmark. - Hardware - 8x Nvidia A100 GPUs diff --git a/.buildkite/nightly-benchmarks/scripts/compare-json-results.py b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py index 12c4ba6aa69a6..50431d0cd4c5e 100644 --- a/.buildkite/nightly-benchmarks/scripts/compare-json-results.py +++ b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py @@ -3,44 +3,129 @@ import argparse import json import os +from importlib import util import pandas as pd +plotly_found = util.find_spec("plotly.express") is not None + def compare_data_columns( files, name_column, data_column, info_cols, drop_column, debug=False ): - print("\ncompare_data_column: " + data_column) + """ + Align concatenation by keys derived from info_cols instead of row order. + - Pick one canonical key list: subset of info_cols present in ALL files. 
+ - For each file: set index to those keys, aggregate duplicates + - (mean for metric, first for names). + - Concat along axis=1 (indexes align), then reset_index so callers can + - group by columns. + - If --debug, add a _name column per file. + """ + print("\ncompare_data_column:", data_column) + frames = [] raw_data_cols = [] compare_frames = [] + + # 1) choose a canonical key list from info_cols that exists in ALL files + cols_per_file = [] + for f in files: + try: + df_tmp = pd.read_json(f, orient="records") + except Exception as err: + raise ValueError(f"Failed to read {f}") from err + cols_per_file.append(set(df_tmp.columns)) + + key_cols = [c for c in info_cols if all(c in cset for cset in cols_per_file)] + if not key_cols: + # soft fallback: use any info_cols present in the first file + key_cols = [c for c in info_cols if c in list(cols_per_file[0])] + if not key_cols: + raise ValueError( + "No common key columns found from info_cols across the input files." + ) + + # 2) build a single "meta" block (keys as columns) once, aligned by the key index + meta_added = False + for file in files: - data_df = pd.read_json(file) - serving_df = data_df.dropna(subset=[drop_column], ignore_index=True) - # Show all info columns in the first couple columns - if not frames: - for col in info_cols: - if col not in serving_df.columns: - print(f"Skipping missing column: {col}") - continue - frames.append(serving_df[col]) - # only show test name under debug mode - if debug is True: - serving_df = serving_df.rename(columns={name_column: file + "_name"}) - frames.append(serving_df[file + "_name"]) + df = pd.read_json(file, orient="records") - file = "/".join(file.split("/")[:-1]) - serving_df = serving_df.rename(columns={data_column: file}) - frames.append(serving_df[file]) - raw_data_cols.append(file) - compare_frames.append(serving_df[file]) + # Keep rows that actually have the compared metric (same as original behavior) + if drop_column in df.columns: + df = 
df.dropna(subset=[drop_column], ignore_index=True) + + # Stabilize numeric key columns (harmless if missing) + for c in ( + "Input Len", + "Output Len", + "TP Size", + "PP Size", + "# of max concurrency.", + "qps", + ): + if c in df.columns: + df[c] = pd.to_numeric(df[c], errors="coerce") + + # Ensure all key columns exist + for c in key_cols: + if c not in df.columns: + df[c] = pd.NA + + # Set index = key_cols and aggregate duplicates → unique MultiIndex + df_idx = df.set_index(key_cols, drop=False) + + # meta (key columns), unique per key + meta = df_idx[key_cols] + if not meta.index.is_unique: + meta = meta.groupby(level=key_cols, dropna=False).first() + + # metric series for this file, aggregated to one row per key + file_label = "/".join(file.split("/")[:-1]) or os.path.basename(file) + s = df_idx[data_column] + if not s.index.is_unique: + s = s.groupby(level=key_cols, dropna=False).mean() + s.name = file_label # column label like original + + # add meta once (from first file) so keys are the leftmost columns + if not meta_added: + frames.append(meta) + meta_added = True + + # (NEW) debug: aligned test-name column per file + if debug and name_column in df_idx.columns: + name_s = df_idx[name_column] + if not name_s.index.is_unique: + name_s = name_s.groupby(level=key_cols, dropna=False).first() + name_s.name = f"{file_label}_name" + frames.append(name_s) + + frames.append(s) + raw_data_cols.append(file_label) + compare_frames.append(s) + + # Generalize ratio: for any file N>=2, add ratio (fileN / file1) if len(compare_frames) >= 2: - # Compare numbers among two files - ratio_df = compare_frames[1] / compare_frames[0] - frames.append(ratio_df) - compare_frames.pop(1) + base = compare_frames[0] + current = compare_frames[-1] + ratio = current / base + ratio = ratio.mask(base == 0) # avoid inf when baseline is 0 + ratio.name = f"Ratio 1 vs {len(compare_frames)}" + frames.append(ratio) + # 4) concat on columns with aligned MultiIndex; + # then reset_index to return 
keys as columns concat_df = pd.concat(frames, axis=1) + concat_df = concat_df.reset_index(drop=True).reset_index() + if "index" in concat_df.columns: + concat_df = concat_df.drop(columns=["index"]) + + # Ensure key/info columns appear first (in your info_cols order) + front = [c for c in info_cols if c in concat_df.columns] + rest = [c for c in concat_df.columns if c not in front] + concat_df = concat_df[front + rest] + print(raw_data_cols) return concat_df, raw_data_cols @@ -67,6 +152,15 @@ def split_json_by_tp_pp( df = pd.DataFrame(data) + # Keep only "serving" tests + name_col = next( + (c for c in ["Test name", "test_name", "Test Name"] if c in df.columns), None + ) + if name_col: + df = df[ + df[name_col].astype(str).str.contains(r"serving", case=False, na=False) + ].copy() + # Handle alias column names rename_map = { "tp_size": "TP Size", @@ -181,7 +275,6 @@ if __name__ == "__main__": f"Expected subset: {filtered_info_cols}, " f"but DataFrame has: {list(output_df.columns)}" ) - output_df_sorted = output_df.sort_values(by=existing_group_cols) output_groups = output_df_sorted.groupby(existing_group_cols, dropna=False) for name, group in output_groups: @@ -189,8 +282,7 @@ if __name__ == "__main__": text_file.write(html_msgs_for_data_cols[i]) text_file.write(html) - if plot is True: - import pandas as pd + if plot and plotly_found: import plotly.express as px df = group[raw_data_cols] diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py index 496ee6083abde..77047636bb951 100644 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -4,7 +4,6 @@ import argparse import json import os -import re import shlex from importlib import util from pathlib import Path @@ -12,6 +11,7 @@ from typing import Any import pandas as pd import psutil +import regex as re 
from tabulate import tabulate # latency results and the keys that will be printed into markdown diff --git a/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh b/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh index 06d7b5ed484da..a00de940cbbb8 100644 --- a/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh +++ b/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh @@ -382,7 +382,7 @@ run_genai_perf_tests() { client_command="genai-perf profile \ -m $model \ --service-kind openai \ - --backend vllm \ + --backend "$backend" \ --endpoint-type chat \ --streaming \ --url localhost:$port \ diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 6314afd652340..f96c38bf57db7 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -1,4 +1,20 @@ steps: + # aarch64 + CUDA builds + - label: "Build arm64 wheel - CUDA 12.8" + id: build-wheel-arm64-cuda-12-8 + agents: + queue: arm64_cpu_queue_postmerge + commands: + # #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here: + # https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7 + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." 
+ - "mkdir artifacts" + - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" + - "bash .buildkite/scripts/upload-wheels.sh" + env: + DOCKER_BUILDKIT: "1" + + # x86 + CUDA builds - label: "Build wheel - CUDA 12.8" id: build-wheel-cuda-12-8 agents: @@ -11,7 +27,12 @@ steps: env: DOCKER_BUILDKIT: "1" + - block: "Build CUDA 12.6 wheel" + key: block-build-cu126-wheel + depends_on: ~ + - label: "Build wheel - CUDA 12.6" + depends_on: block-build-cu126-wheel id: build-wheel-cuda-12-6 agents: queue: cpu_queue_postmerge @@ -52,7 +73,7 @@ steps: queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." 
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" - label: "Annotate release workflow" diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index 5e5a532cb57d5..df0bae0c9cbff 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -121,7 +121,6 @@ fi if [[ $commands == *" kernels/quantization"* ]]; then commands="${commands} \ --ignore=kernels/quantization/test_int8_quant.py \ - --ignore=kernels/quantization/test_aqlm.py \ --ignore=kernels/quantization/test_machete_mm.py \ --ignore=kernels/quantization/test_block_fp8.py \ --ignore=kernels/quantization/test_block_int8.py \ diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 57a7bc4e5f5df..9dec9f8e9eb32 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -46,6 +46,11 @@ function cpu_tests() { set -e python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" + # Run kernel tests + docker exec cpu-test-"$NUMA_NODE" bash -c " + set -e + pytest -v -s tests/kernels/test_onednn.py" + # Run basic model test docker exec cpu-test-"$NUMA_NODE" bash -c " set -e @@ -99,4 +104,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 40 mins. 
export -f cpu_tests -timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh index b571618f48c2b..1073a4ee30afa 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh @@ -61,7 +61,7 @@ echo "Results will be stored in: $RESULTS_DIR" echo "--- Installing Python dependencies ---" python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \ && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \ - && python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \ + && python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \ && python3 -m pip install --progress-bar off hf-transfer echo "--- Python dependencies installed ---" export VLLM_USE_V1=1 diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index d55a786e41e8b..505664f3aecd0 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -61,7 +61,7 @@ echo "Results will be stored in: $RESULTS_DIR" echo "--- Installing Python dependencies ---" python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \ && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \ - && python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \ + && python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \ && python3 -m pip install --progress-bar off hf-transfer echo "--- Python dependencies installed ---" export VLLM_USE_V1=1 diff --git 
a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index deb61a9bafab6..73f3e63fbf5f6 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -23,10 +23,15 @@ docker run \ --device /dev/dri \ -v /dev/dri/by-path:/dev/dri/by-path \ --entrypoint="" \ + -e "HF_TOKEN=${HF_TOKEN}" \ + -e "ZE_AFFINITY_MASK=${ZE_AFFINITY_MASK}" \ --name "${container_name}" \ "${image_name}" \ - sh -c ' + bash -c ' + set -e + echo $ZE_AFFINITY_MASK VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager + VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp cd tests @@ -35,8 +40,8 @@ docker run \ pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py pytest -v -s v1/structured_output - pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py - pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py + pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py + pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py pytest -v -s 
v1/test_serial_utils.py pytest -v -s v1/test_utils.py pytest -v -s v1/test_metrics_reader.py diff --git a/.buildkite/scripts/tpu/cleanup_docker.sh b/.buildkite/scripts/tpu/cleanup_docker.sh index 209d9c4341cdd..740d81fb39bb0 100755 --- a/.buildkite/scripts/tpu/cleanup_docker.sh +++ b/.buildkite/scripts/tpu/cleanup_docker.sh @@ -17,7 +17,7 @@ if [ "$disk_usage" -gt "$threshold" ]; then # Remove dangling images (those that are not tagged and not used by any container) docker image prune -f # Remove unused volumes / force the system prune for old images as well. - docker volume prune -f && docker system prune --force --filter "until=72h" --all + docker volume prune -f && docker system prune --force --filter "until=24h" --all echo "Docker images and volumes cleanup completed." else echo "Disk usage is below $threshold%. No cleanup needed." diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh index 037897e53dbef..745f285c008ad 100644 --- a/.buildkite/scripts/upload-wheels.sh +++ b/.buildkite/scripts/upload-wheels.sh @@ -14,8 +14,19 @@ fi # Get the single wheel file wheel="${wheel_files[0]}" -# Rename 'linux' to 'manylinux1' in the wheel filename -new_wheel="${wheel/linux/manylinux1}" +# Detect architecture and rename 'linux' to appropriate manylinux version +arch=$(uname -m) +if [[ $arch == "x86_64" ]]; then + manylinux_version="manylinux1" +elif [[ $arch == "aarch64" ]]; then + manylinux_version="manylinux2014" +else + echo "Warning: Unknown architecture $arch, using manylinux1 as default" + manylinux_version="manylinux1" +fi + +# Rename 'linux' to the appropriate manylinux version in the wheel filename +new_wheel="${wheel/linux/$manylinux_version}" mv -- "$wheel" "$new_wheel" wheel="$new_wheel" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 04d7cdc3d8854..0d3b7a294d963 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -88,15 +88,6 @@ steps: - pytest -v -s 
basic_correctness/test_cpu_offload.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py -- label: Chunked Prefill Test - mirror_hardwares: [amdexperimental] - source_file_dependencies: - - vllm/ - - tests/basic_correctness/test_chunked_prefill - commands: - - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - - label: Core Test # 10min mirror_hardwares: [amdexperimental] fast_check: true @@ -135,7 +126,8 @@ steps: - tests/entrypoints/test_chat_utils commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ + - PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py - pytest -v -s entrypoints/test_chat_utils.py - label: Distributed Tests (4 GPUs) # 10min @@ -252,7 +244,9 @@ steps: - pytest -v -s v1/core - pytest -v -s v1/engine - pytest -v -s v1/entrypoints + - pytest -v -s v1/executor - pytest -v -s v1/sample + - pytest -v -s v1/logits_processors - pytest -v -s v1/worker - pytest -v -s v1/structured_output - pytest -v -s v1/spec_decode @@ -294,15 +288,6 @@ steps: - python3 offline_inference/basic/score.py - VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 -- label: Prefix Caching Test # 9min - mirror_hardwares: 
[amdexperimental] - source_file_dependencies: - - vllm/ - - tests/prefix_caching - commands: - - pytest -v -s prefix_caching - - - label: Platform Tests (CUDA) mirror_hardwares: [amdexperimental] source_file_dependencies: @@ -344,6 +329,7 @@ steps: - pytest -v -s compile/test_sequence_parallelism.py - pytest -v -s compile/test_async_tp.py - pytest -v -s compile/test_fusion_all_reduce.py + - pytest -v -s compile/test_decorator.py - label: PyTorch Fullgraph Smoke Test # 9min mirror_hardwares: [amdexperimental] @@ -357,6 +343,7 @@ steps: - pytest -v -s compile/piecewise/test_simple.py - pytest -v -s compile/piecewise/test_toy_llama.py - pytest -v -s compile/piecewise/test_full_cudagraph.py + - pytest -v -s compile/piecewise/test_multiple_graphs.py - label: PyTorch Fullgraph Test # 18min mirror_hardwares: [amdexperimental] @@ -399,9 +386,11 @@ steps: - label: Kernels MoE Test %N mirror_hardwares: [amdexperimental] source_file_dependencies: + - csrc/quantization/cutlass_w8a8/moe/ - csrc/moe/ - tests/kernels/moe - vllm/model_executor/layers/fused_moe/ + - vllm/distributed/device_communicators/ commands: - pytest -v -s kernels/moe --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 2 @@ -466,13 +455,11 @@ steps: - label: LM Eval Small Models # 53min mirror_hardwares: [amdexperimental] - working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" source_file_dependencies: - csrc/ - vllm/model_executor/layers/quantization commands: - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 - label: OpenAI API correctness mirror_hardwares: [amdexperimental] @@ -560,6 +547,15 @@ steps: commands: - pytest -v -s models/language/pooling -m 'not core_model' +- label: Multi-Modal Processor Test + source_file_dependencies: + - vllm/ + - 
tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py + - pytest -v -s models/multimodal/processing/test_tensor_schema.py + - label: Multi-Modal Models Test (Standard) mirror_hardwares: [amdexperimental] torch_nightly: true @@ -569,9 +565,7 @@ steps: commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pip freeze | grep -E 'torch' - - pytest -v -s models/multimodal/processing - - pytest -v -s --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/test_tensor_schema.py models/multimodal -m core_model - - pytest -v -s models/multimodal/test_tensor_schema.py -m core_model # Needs mp_method="spawn" + - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing - cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work - label: Multi-Modal Models Test (Extended) 1 @@ -582,7 +576,7 @@ steps: - tests/models/multimodal commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - - pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model' + - pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing - label: Multi-Modal Models Test (Extended) 2 mirror_hardwares: [amdexperimental] @@ -645,8 +639,10 @@ steps: - vllm/model_executor/layers/fused_moe/cutlass_moe.py - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py - vllm/compilation/fusion.py + - 
vllm/compilation/fusion_attn.py commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py @@ -659,10 +655,14 @@ steps: - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py + - pytest -v -s tests/kernels/moe/test_mxfp4_moe.py # Fusion - pytest -v -s tests/compile/test_fusion_all_reduce.py + - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern + - pytest -v -s tests/kernels/moe/test_flashinfer.py ##### 1 GPU test ##### ##### multi gpus test ##### @@ -845,3 +845,10 @@ steps: commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 + +- label: Qwen MoE EP Test # optional + gpu: h200 + optional: true + num_gpus: 2 + commands: + - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 /vllm-workspace/examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index b0dd5e99d4c72..c087fd555c661 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -10,6 +10,7 @@ /vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 +/vllm/model_executor/layers/mamba @tdoublep /vllm/multimodal @DarkLight1337 @ywang96 /vllm/vllm_flash_attn @LucasWilkinson /vllm/lora @jeejeelee @@ -25,11 +26,11 @@ CMakeLists.txt 
@tlrmchlsmth @LucasWilkinson # vLLM V1 /vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat /vllm/v1/structured_output @mgoin @russellb @aarnphm +/vllm/v1/attention/backends/triton_attn.py @tdoublep # Test ownership /.buildkite/lm-eval-harness @mgoin @simon-mo /tests/async_engine @njhill @robertgshaw2-redhat @simon-mo -/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac /tests/distributed/test_multi_node_assignment.py @youkaichao /tests/distributed/test_pipeline_parallel.py @youkaichao /tests/distributed/test_same_node.py @youkaichao @@ -44,6 +45,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /tests/v1/structured_output @mgoin @russellb @aarnphm /tests/weight_loading @mgoin @youkaichao @yewentao256 /tests/lora @jeejeelee +/tests/models/language/generation/test_hybrid.py @tdoublep # Docs /docs @hmellor @@ -72,3 +74,15 @@ mkdocs.yaml @hmellor /vllm/model_executor/models/pixtral*.py @patrickvonplaten /vllm/transformers_utils/configs/mistral.py @patrickvonplaten /vllm/transformers_utils/tokenizers/mistral.py @patrickvonplaten + +# Kernels +/vllm/attention/ops/chunked_prefill_paged_decode.py @tdoublep +/vllm/attention/ops/triton_unified_attention.py @tdoublep + +# ROCm related: specify owner with write access to notify AMD folks for careful code review +/docker/Dockerfile.rocm* @gshtras +/vllm/v1/attention/backends/rocm*.py @gshtras +/vllm/v1/attention/backends/mla/rocm*.py @gshtras +/vllm/attention/ops/rocm*.py @gshtras +/vllm/model_executor/layers/fused_moe/rocm*.py @gshtras + diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 1b30c1292df85..8043df65d5585 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -7,8 +7,6 @@ PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTT ## Test Result -## (Optional) Documentation Update - ---
Essential Elements of an Effective PR Description Checklist @@ -17,6 +15,7 @@ PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTT - [ ] The test plan, such as providing test command. - [ ] The test results, such as pasting the results comparison before and after, or e2e results - [ ] (Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model. +- [ ] (Optional) Release notes update. If your change is user facing, please update the release notes draft in the [Google Doc](https://docs.google.com/document/d/1YyVqrgX4gHTtrstbq8oWUImOyPCKSGnJ7xtTpmXzlRs/edit?tab=t.0).
**BEFORE SUBMITTING, PLEASE READ ** (anything written below this line will be removed by GitHub Actions) diff --git a/.github/workflows/issue_autolabel.yml b/.github/workflows/issue_autolabel.yml new file mode 100644 index 0000000000000..6401d6586cc3d --- /dev/null +++ b/.github/workflows/issue_autolabel.yml @@ -0,0 +1,305 @@ +name: Label issues based on keywords +on: + issues: + types: [opened, edited, reopened] +permissions: + issues: write # needed so the workflow can add labels + contents: read +concurrency: + group: issue-labeler-${{ github.event.issue.number }} + cancel-in-progress: true +jobs: + add-labels: + runs-on: ubuntu-latest + steps: + - name: Label issues based on keywords + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + with: + script: | + // Configuration: Add new labels and keywords here + const labelConfig = { + rocm: { + // Keyword search - matches whole words only (with word boundaries) + keywords: [ + { + term: "composable kernel", + searchIn: "both" + }, + { + term: "rccl", + searchIn: "body" // only search in body + }, + { + term: "migraphx", + searchIn: "title" // only search in title + }, + { + term: "hipgraph", + searchIn: "both" + }, + { + term: "ROCm System Management Interface", + searchIn: "body" + }, + ], + + // Substring search - matches anywhere in text (partial matches) + substrings: [ + { + term: "VLLM_ROCM_", + searchIn: "both" + }, + { + term: "rocm", + searchIn: "title" + }, + { + term: "amd", + searchIn: "title" + }, + { + term: "hip-", + searchIn: "both" + }, + { + term: "gfx", + searchIn: "both" + }, + { + term: "cdna", + searchIn: "both" + }, + { + term: "rdna", + searchIn: "both" + }, + { + term: "torch_hip", + searchIn: "body" // only in body + }, + { + term: "_hip", + searchIn: "both" + }, + { + term: "hip_", + searchIn: "both" + }, + + // ROCm tools and libraries + { + term: "hipify", + searchIn: "both" + }, + ], + + // Regex patterns - for complex pattern matching + regexPatterns: [ + 
{ + pattern: "\\bmi\\d{3}[a-z]*\\b", + description: "AMD GPU names (mi + 3 digits + optional letters)", + flags: "gi", + searchIn: "both" // "title", "body", or "both" + } + ], + }, + }; + + // Helper function to create regex based on search type + function createSearchRegex(term, type) { + // Escape special regex characters in the term + const escapedTerm = term.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); + + switch (type) { + case 'keyword': + // Word boundary search - matches whole words only + return new RegExp(`\\b${escapedTerm}\\b`, "gi"); + case 'substring': + // Substring search - matches anywhere in the text + return new RegExp(escapedTerm, "gi"); + default: + throw new Error(`Unknown search type: ${type}`); + } + } + + // Helper function to find matching terms in text with line information + function findMatchingTermsWithLines(text, searchTerms = [], searchType = 'keyword', searchLocation = '') { + const matches = []; + const lines = text.split('\n'); + + for (const termConfig of searchTerms) { + let regex; + let term, searchIn, pattern, description, flags; + + // Handle different input formats (string or object) + if (typeof termConfig === 'string') { + term = termConfig; + searchIn = 'both'; // default + } else { + term = termConfig.term; + searchIn = termConfig.searchIn || 'both'; + pattern = termConfig.pattern; + description = termConfig.description; + flags = termConfig.flags; + } + + // Skip if this term shouldn't be searched in the current location + if (searchIn !== 'both' && searchIn !== searchLocation) { + continue; + } + + // Create appropriate regex + if (searchType === 'regex') { + regex = new RegExp(pattern, flags || "gi"); + } else { + regex = createSearchRegex(term, searchType); + } + + const termMatches = []; + + // Check each line for matches + lines.forEach((line, lineIndex) => { + const lineMatches = line.match(regex); + if (lineMatches) { + lineMatches.forEach(match => { + termMatches.push({ + match: match, + lineNumber: lineIndex + 1, + 
lineContent: line.trim(), + searchType: searchType, + searchLocation: searchLocation, + originalTerm: term || pattern, + description: description, + // Show context around the match in the line + context: line.length > 100 ? + line.substring(Math.max(0, line.toLowerCase().indexOf(match.toLowerCase()) - 30), + line.toLowerCase().indexOf(match.toLowerCase()) + match.length + 30) + '...' + : line.trim() + }); + }); + } + }); + + if (termMatches.length > 0) { + matches.push({ + term: term || (description || pattern), + searchType: searchType, + searchLocation: searchLocation, + searchIn: searchIn, + pattern: pattern, + matches: termMatches, + count: termMatches.length + }); + } + } + + return matches; + } + + // Helper function to check if label should be added + async function processLabel(labelName, config) { + const body = context.payload.issue.body || ""; + const title = context.payload.issue.title || ""; + + core.notice(`Processing label: ${labelName}`); + core.notice(`Issue Title: "${title}"`); + core.notice(`Issue Body length: ${body.length} characters`); + + let shouldAddLabel = false; + let allMatches = []; + let reason = ''; + + const keywords = config.keywords || []; + const substrings = config.substrings || []; + const regexPatterns = config.regexPatterns || []; + + core.notice(`Searching with ${keywords.length} keywords, ${substrings.length} substrings, and ${regexPatterns.length} regex patterns`); + + // Search in title + if (title.trim()) { + core.notice(`Searching in title: "${title}"`); + + const titleKeywordMatches = findMatchingTermsWithLines(title, keywords, 'keyword', 'title'); + const titleSubstringMatches = findMatchingTermsWithLines(title, substrings, 'substring', 'title'); + const titleRegexMatches = findMatchingTermsWithLines(title, regexPatterns, 'regex', 'title'); + + allMatches.push(...titleKeywordMatches, ...titleSubstringMatches, ...titleRegexMatches); + } + + // Search in body + if (body.trim()) { + core.notice(`Searching in body 
(${body.length} characters)`); + + const bodyKeywordMatches = findMatchingTermsWithLines(body, keywords, 'keyword', 'body'); + const bodySubstringMatches = findMatchingTermsWithLines(body, substrings, 'substring', 'body'); + const bodyRegexMatches = findMatchingTermsWithLines(body, regexPatterns, 'regex', 'body'); + + allMatches.push(...bodyKeywordMatches, ...bodySubstringMatches, ...bodyRegexMatches); + } + + if (allMatches.length > 0) { + core.notice(`Found ${allMatches.length} matching term(s):`); + + for (const termMatch of allMatches) { + const locationText = termMatch.searchLocation === 'title' ? 'title' : 'body'; + const searchInText = termMatch.searchIn === 'both' ? 'both' : termMatch.searchIn; + + if (termMatch.searchType === 'regex') { + core.notice(` 📍 Regex: "${termMatch.term}" (pattern: ${termMatch.pattern}) found ${termMatch.count} time(s) in ${locationText} (configured to search in: ${searchInText}):`); + } else { + core.notice(` 📍 Term: "${termMatch.term}" (${termMatch.searchType} search) found ${termMatch.count} time(s) in ${locationText} (configured to search in: ${searchInText}):`); + } + + // Show details for each match + termMatch.matches.forEach((match, index) => { + core.notice(` ${index + 1}. 
Line ${match.lineNumber} in ${match.searchLocation}: "${match.match}" [${match.searchType}]`); + if (match.description) { + core.notice(` Description: ${match.description}`); + } + core.notice(` Context: ${match.context}`); + if (match.lineContent !== match.context) { + core.notice(` Full line: ${match.lineContent}`); + } + }); + } + + shouldAddLabel = true; + const totalMatches = allMatches.reduce((sum, t) => sum + t.count, 0); + const titleMatches = allMatches.filter(t => t.searchLocation === 'title').reduce((sum, t) => sum + t.count, 0); + const bodyMatches = allMatches.filter(t => t.searchLocation === 'body').reduce((sum, t) => sum + t.count, 0); + const keywordMatches = allMatches.filter(t => t.searchType === 'keyword').reduce((sum, t) => sum + t.count, 0); + const substringMatches = allMatches.filter(t => t.searchType === 'substring').reduce((sum, t) => sum + t.count, 0); + const regexMatches = allMatches.filter(t => t.searchType === 'regex').reduce((sum, t) => sum + t.count, 0); + + reason = `Found ${totalMatches} total matches (${titleMatches} in title, ${bodyMatches} in body) - ${keywordMatches} keyword matches, ${substringMatches} substring matches, ${regexMatches} regex matches`; + } + + core.notice(`Final decision: ${shouldAddLabel ? 'ADD LABEL' : 'DO NOT ADD LABEL'}`); + core.notice(`Reason: ${reason || 'No matching terms found'}`); + + if (shouldAddLabel) { + const existingLabels = context.payload.issue.labels.map(l => l.name); + if (!existingLabels.includes(labelName)) { + await github.rest.issues.addLabels({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + labels: [labelName], + }); + core.notice(`Label "${labelName}" added. 
${reason}`); + return true; + } + core.notice(`Label "${labelName}" already present.`); + return false; + } + + core.notice(`No matching terms found for label "${labelName}".`); + return false; + } + + // Process all configured labels + const processLabels = Object.entries(labelConfig) + .map(([labelName, config]) => processLabel(labelName, config)); + const labelsAdded = await Promise.all(processLabels); + const numLabelsAdded = labelsAdded.reduce((x, y) => x + y, 0); + core.notice(`Processing complete. ${numLabelsAdded} label(s) added.`); \ No newline at end of file diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml deleted file mode 100644 index 2b1086b7faf43..0000000000000 --- a/.github/workflows/lint-and-deploy.yaml +++ /dev/null @@ -1,89 +0,0 @@ -name: Lint and Deploy Charts - -on: pull_request - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -permissions: - contents: read - -jobs: - lint-and-deploy: - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - - name: Set up Helm - uses: azure/setup-helm@b9e51907a09c216f16ebe8536097933489208112 # v4.3.0 - with: - version: v3.14.4 - - #Python is required because ct lint runs Yamale and yamllint which require Python. 
- - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 - with: - python-version: '3.13' - - - name: Set up chart-testing - uses: helm/chart-testing-action@0d28d3144d3a25ea2cc349d6e59901c4ff469b3b # v2.7.0 - with: - version: v3.10.1 - - - name: Run chart-testing (lint) - run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/online_serving/chart-helm --charts examples/online_serving/chart-helm - - - name: Setup minio - run: | - docker network create vllm-net - docker run -d -p 9000:9000 --name minio --net vllm-net \ - -e "MINIO_ACCESS_KEY=minioadmin" \ - -e "MINIO_SECRET_KEY=minioadmin" \ - -v /tmp/data:/data \ - -v /tmp/config:/root/.minio \ - minio/minio server /data - export AWS_ACCESS_KEY_ID=minioadmin - export AWS_SECRET_ACCESS_KEY=minioadmin - export AWS_EC2_METADATA_DISABLED=true - mkdir opt-125m - cd opt-125m && curl -O -Ls "https://huggingface.co/facebook/opt-125m/resolve/main/{pytorch_model.bin,config.json,generation_config.json,merges.txt,special_tokens_map.json,tokenizer_config.json,vocab.json}" && cd .. - aws --endpoint-url http://127.0.0.1:9000/ s3 mb s3://testbucket - aws --endpoint-url http://127.0.0.1:9000/ s3 cp opt-125m/ s3://testbucket/opt-125m --recursive - - - name: Create kind cluster - uses: helm/kind-action@a1b0e391336a6ee6713a0583f8c6240d70863de3 # v1.12.0 - - - name: Build the Docker image vllm cpu - run: docker buildx build -f docker/Dockerfile.cpu -t vllm-cpu-env . 
- - - name: Configuration of docker images, network and namespace for the kind cluster - run: | - docker pull amazon/aws-cli:2.6.4 - kind load docker-image amazon/aws-cli:2.6.4 --name chart-testing - kind load docker-image vllm-cpu-env:latest --name chart-testing - docker network connect vllm-net "$(docker ps -aqf "name=chart-testing-control-plane")" - kubectl create ns ns-vllm - - - name: Run chart-testing (install) - run: | - export AWS_ACCESS_KEY_ID=minioadmin - export AWS_SECRET_ACCESS_KEY=minioadmin - sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" & - helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env" - - - name: curl test - run: | - kubectl -n ns-vllm port-forward service/test-vllm-service 8001:80 & - sleep 10 - CODE="$(curl -v -f --location http://localhost:8001/v1/completions \ - --header "Content-Type: application/json" \ - --data '{ - "model": "opt-125m", - "prompt": "San Francisco is a", - "max_tokens": 7, - "temperature": 0 - }'):$CODE" - echo "$CODE" diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml 
deleted file mode 100644 index bfd02879965ee..0000000000000 --- a/.github/workflows/publish.yml +++ /dev/null @@ -1,111 +0,0 @@ -# This workflow will upload a Python Package to Release asset -# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions - -name: Create Release - -on: - push: - tags: - - v* - -# Needed to create release and upload assets -permissions: - contents: write - -jobs: - release: - # Retrieve tag and create release - name: Create Release - runs-on: ubuntu-latest - outputs: - upload_url: ${{ steps.create_release.outputs.upload_url }} - steps: - - name: Checkout - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - name: Extract branch info - shell: bash - run: | - echo "release_tag=${GITHUB_REF#refs/*/}" >> "$GITHUB_ENV" - - - name: Create Release - id: create_release - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 - env: - RELEASE_TAG: ${{ env.release_tag }} - with: - github-token: "${{ secrets.GITHUB_TOKEN }}" - script: | - const script = require('.github/workflows/scripts/create_release.js') - await script(github, context, core) - - # NOTE(simon): No longer build wheel using GitHub Actions. See buildkite's release workflow. - # wheel: - # name: Build Wheel - # runs-on: ${{ matrix.os }} - # needs: release - - # strategy: - # fail-fast: false - # matrix: - # os: ['ubuntu-20.04'] - # python-version: ['3.9', '3.10', '3.11', '3.12'] - # pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements/cuda.txt. 
- # cuda-version: ['11.8', '12.1'] - - # steps: - # - name: Checkout - # uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - # - name: Setup ccache - # uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14 - # with: - # create-symlink: true - # key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }} - - # - name: Set up Linux Env - # if: ${{ runner.os == 'Linux' }} - # run: | - # bash -x .github/workflows/scripts/env.sh - - # - name: Set up Python - # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - # with: - # python-version: ${{ matrix.python-version }} - - # - name: Install CUDA ${{ matrix.cuda-version }} - # run: | - # bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} - - # - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }} - # run: | - # bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }} - - # - name: Build wheel - # shell: bash - # env: - # CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size - # run: | - # bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} - # wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename) - # asset_name=${wheel_name//"linux"/"manylinux1"} - # echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV" - # echo "asset_name=${asset_name}" >> "$GITHUB_ENV" - - # - name: Upload Release Asset - # uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2 - # env: - # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - # with: - # upload_url: ${{ needs.release.outputs.upload_url }} - # asset_path: ./dist/${{ env.wheel_name }} - # asset_name: ${{ env.asset_name }} - # asset_content_type: application/* - - # (Danielkinz): This last step will publish the .whl to pypi. 
Warning: untested - # - name: Publish package - # uses: pypa/gh-action-pypi-publish@release/v1.8 - # with: - # repository-url: https://test.pypi.org/legacy/ - # password: ${{ secrets.PYPI_API_TOKEN }} - # skip-existing: true diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml index 16ae1aadb96be..1ee605dc7bb0d 100644 --- a/.github/workflows/reminder_comment.yml +++ b/.github/workflows/reminder_comment.yml @@ -12,16 +12,43 @@ jobs: uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 with: script: | - github.rest.issues.createComment({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' + - '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' + - 'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. 
If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org.\n\n' + - 'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' + - 'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' + - '🚀' - }) + try { + // Get the PR author + const prAuthor = context.payload.pull_request.user.login; + + // Check if this is the author's first PR in this repository + // Use GitHub's search API to find all PRs by this author + const { data: searchResults } = await github.rest.search.issuesAndPullRequests({ + q: `repo:${context.repo.owner}/${context.repo.repo} type:pr author:${prAuthor}`, + per_page: 100 + }); + + const authorPRCount = searchResults.total_count; + + console.log(`Found ${authorPRCount} PRs by ${prAuthor}`); + + // Only post comment if this is the first PR (only one PR by this author) + if (authorPRCount === 1) { + console.log(`Posting welcome comment for first-time contributor: ${prAuthor}`); + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' + + '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' + + 'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. \n\n' + + 'You ask your reviewers to trigger select CI tests on top of `fastcheck` CI. 
\n\n' + + 'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' + + 'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' + + 'If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.\n\n' + + '🚀' + }); + } else { + console.log(`Skipping comment for ${prAuthor} - not their first PR (${authorPRCount} PRs found)`); + } + } catch (error) { + console.error('Error checking PR history or posting comment:', error); + // Don't fail the workflow, just log the error + } env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/CMakeLists.txt b/CMakeLists.txt index 471fb2ad376ac..ac500440966fe 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -30,7 +30,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) # Supported python versions. These versions will be searched in order, the # first match will be selected. These should be kept in sync with setup.py. # -set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12") +set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12" "3.13") # Supported AMD GPU architectures. 
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201") @@ -286,7 +286,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_MakeAvailable(cutlass) list(APPEND VLLM_EXT_SRC - "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" "csrc/permute_cols.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" @@ -359,9 +358,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) set(MARLIN_SRCS - "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" - "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") @@ -754,6 +751,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "found in CUDA target architectures") endif() endif() + + # Only build W4A8 kernels if we are building for something compatible with sm90a + cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS) + set(SRCS + "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu") + + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${W4A8_ARCHS}") + + list(APPEND VLLM_EXT_SRC "${SRCS}") + + message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 + AND W4A8_ARCHS) + message(STATUS "Not building W4A8 kernels as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running w4a16 quantized models on " + "Hopper.") + else() + message(STATUS "Not building W4A8 kernels as no compatible archs " + "found in CUDA target architectures") + endif() + endif() + # if CUDA endif endif() @@ -794,7 +818,9 @@ set(VLLM_MOE_EXT_SRC "csrc/moe/topk_softmax_kernels.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") 
- list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu") + list(APPEND VLLM_MOE_EXT_SRC + "csrc/moe/moe_wna16.cu" + "csrc/moe/grouped_topk_kernels.cu") endif() if(VLLM_GPU_LANG STREQUAL "CUDA") diff --git a/README.md b/README.md index fd8b02ac1f781..8812aac4ea266 100644 --- a/README.md +++ b/README.md @@ -18,14 +18,16 @@ Easy, fast, and cheap LLM serving for everyone *Latest News* 🔥 +- [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH). +- [2025/08] We hosted [vLLM Korea Meetup](https://luma.com/cgcgprmh) with Red Hat and Rebellions! We shared the latest advancements in vLLM along with project spotlights from the vLLM Korea community. Please find the meetup slides [here](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view). - [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152). -- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing). - [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/). - [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
Previous News +- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing). - [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). - [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing). - [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing). diff --git a/SECURITY.md b/SECURITY.md index 414669fb3712e..d6319cdb1ac27 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -42,4 +42,9 @@ For certain security issues of CRITICAL, HIGH, or MODERATE severity level, we ma * If you wish to be added to the prenotification group, please send an email copying all the members of the [vulnerability management team](https://docs.vllm.ai/en/latest/contributing/vulnerability_management.html). Each vendor contact will be analyzed on a case-by-case basis. +* Organizations and vendors who either ship or use vLLM are eligible to join the prenotification group if they meet at least one of the following qualifications: + * Substantial internal deployment leveraging the upstream vLLM project. + * Established internal security teams and comprehensive compliance measures. + * Active and consistent contributions to the upstream vLLM project. 
+ * We may withdraw organizations from receiving future prenotifications if they release fixes or any other information about issues before they are public. Group membership may also change based on policy refinements for who may be included. diff --git a/benchmarks/README.md b/benchmarks/README.md index d6442a4fc3872..38072152b653b 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -22,6 +22,25 @@ become available. ✅ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + + ShareGPT4V (Image) + ✅ + ✅ + + wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/resolve/main/sharegpt4v_instruct_gpt4-vision_cap100k.json +
+
Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:
+ wget http://images.cocodataset.org/zips/train2017.zip + + + + ShareGPT4Video (Video) + ✅ + ✅ + + git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video + + BurstGPT ✅ @@ -29,7 +48,7 @@ become available. wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv - Sonnet + Sonnet (deprecated) ✅ ✅ Local file: benchmarks/sonnet.txt @@ -40,6 +59,18 @@ become available. ✅ synthetic + + RandomMultiModal (Image/Video) + 🟡 + 🚧 + synthetic + + + Prefix Repetition + ✅ + ✅ + synthetic + HuggingFace-VisionArena ✅ @@ -177,6 +208,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct ```bash vllm bench serve \ --backend openai-chat \ + --endpoint-type openai-chat \ --model Qwen/Qwen2-VL-7B-Instruct \ --endpoint /v1/chat/completions \ --dataset-name hf \ @@ -213,6 +245,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct ```bash vllm bench serve \ --backend openai-chat \ + --endpoint-type openai-chat \ --model Qwen/Qwen2-VL-7B-Instruct \ --endpoint /v1/chat/completions \ --dataset-name hf \ @@ -227,6 +260,7 @@ vllm bench serve \ ```bash vllm bench serve \ --backend openai-chat \ + --endpoint-type openai-chat \ --model Qwen/Qwen2-VL-7B-Instruct \ --endpoint /v1/chat/completions \ --dataset-name hf \ @@ -581,6 +615,20 @@ python3 benchmarks/benchmark_prefix_caching.py \ --input-length-range 128:256 ``` +### Prefix Repetition Dataset + +```bash +vllm bench serve \ + --backend openai \ + --model meta-llama/Llama-2-7b-chat-hf \ + --dataset-name prefix_repetition \ + --num-prompts 100 \ + --prefix-repetition-prefix-len 512 \ + --prefix-repetition-suffix-len 128 \ + --prefix-repetition-num-prefixes 5 \ + --prefix-repetition-output-len 128 +``` +
## ⚡ Example - Request Prioritization Benchmark @@ -616,3 +664,139 @@ python3 benchmarks/benchmark_prioritization.py \ ``` + +## 👁️ Example - Multi-Modal Benchmark + +
+Show more + +
+ +Benchmark the performance of multi-modal requests in vLLM. + +### Images (ShareGPT4V) + +Start vLLM: + +```bash +python -m vllm.entrypoints.openai.api_server \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dtype bfloat16 \ + --limit-mm-per-prompt '{"image": 1}' \ + --allowed-local-media-path /path/to/sharegpt4v/images +``` + +Send requests with images: + +```bash +python benchmarks/benchmark_serving.py \ + --backend openai-chat \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dataset-name sharegpt \ + --dataset-path /path/to/ShareGPT4V/sharegpt4v_instruct_gpt4-vision_cap100k.json \ + --num-prompts 100 \ + --save-result \ + --result-dir ~/vllm_benchmark_results \ + --save-detailed \ + --endpoint /v1/chat/completions +``` + +### Videos (ShareGPT4Video) + +Start vLLM: + +```bash +python -m vllm.entrypoints.openai.api_server \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dtype bfloat16 \ + --limit-mm-per-prompt '{"video": 1}' \ + --allowed-local-media-path /path/to/sharegpt4video/videos +``` + +Send requests with videos: + +```bash +python benchmarks/benchmark_serving.py \ + --backend openai-chat \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dataset-name sharegpt \ + --dataset-path /path/to/ShareGPT4Video/llava_v1_5_mix665k_with_video_chatgpt72k_share4video28k.json \ + --num-prompts 100 \ + --save-result \ + --result-dir ~/vllm_benchmark_results \ + --save-detailed \ + --endpoint /v1/chat/completions +``` + +### Synthetic Random Images (random-mm) + +Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets. + +Notes: + +- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`. +- Video sampling is not yet implemented. 
+ +Start the server (example): + +```bash +vllm serve Qwen/Qwen2.5-VL-3B-Instruct \ + --dtype bfloat16 \ + --max-model-len 16384 \ + --limit-mm-per-prompt '{"image": 3, "video": 0}' \ + --mm-processor-kwargs max_pixels=1003520 +``` + +Benchmark. It is recommended to use the flag `--ignore-eos` to simulate real responses. You can set the size of the output via the arg `random-output-len`. + +Ex.1: Fixed number of items and a single image resolution, enforcing generation of approx 40 tokens: + +```bash +vllm bench serve \ + --backend openai-chat \ + --model Qwen/Qwen2.5-VL-3B-Instruct \ + --endpoint /v1/chat/completions \ + --dataset-name random-mm \ + --num-prompts 100 \ + --max-concurrency 10 \ + --random-prefix-len 25 \ + --random-input-len 300 \ + --random-output-len 40 \ + --random-range-ratio 0.2 \ + --random-mm-base-items-per-request 2 \ + --random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \ + --random-mm-bucket-config '{(224, 224, 1): 1.0}' \ + --request-rate inf \ + --ignore-eos \ + --seed 42 +``` + +The number of items per request can be controlled by passing multiple image buckets: + +```bash + --random-mm-base-items-per-request 2 \ + --random-mm-num-mm-items-range-ratio 0.5 \ + --random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \ + --random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \ +``` + +Flags specific to `random-mm`: + +- `--random-mm-base-items-per-request`: base number of multimodal items per request. +- `--random-mm-num-mm-items-range-ratio`: vary item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items. +- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'. +- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; remaining probabilities are renormalized to sum to 1. Use T=1 for images. Set any T>1 for videos (video sampling not yet supported). 
+ +Behavioral notes: + +- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping. + +How sampling works: + +- Determine per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits. +- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added. +- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing. +This should be seen as an edge case, and this behavior can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number. Note that this might result in errors due to engine config `--limit-mm-per-prompt`. +- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`. + +
diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 1559ca2d92841..ba7c733be0b25 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -34,6 +34,7 @@ class RequestFuncInput: multi_modal_content: Optional[dict | list[dict]] = None ignore_eos: bool = False language: Optional[str] = None + request_id: Optional[str] = None @dataclass @@ -71,6 +72,9 @@ async def async_request_tgi( "inputs": request_func_input.prompt, "parameters": params, } + headers = None + if request_func_input.request_id: + headers = {"x-request-id": request_func_input.request_id} output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len if request_func_input.ignore_eos: @@ -82,7 +86,9 @@ async def async_request_tgi( st = time.perf_counter() most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload) as response: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() @@ -145,6 +151,9 @@ async def async_request_trt_llm( } if request_func_input.ignore_eos: payload["min_length"] = request_func_input.output_len + headers = None + if request_func_input.request_id: + headers = {"x-request-id": request_func_input.request_id} output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -152,7 +161,9 @@ async def async_request_trt_llm( st = time.perf_counter() most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload) as response: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() @@ -211,6 +222,8 @@ async def async_request_deepspeed_mii( "top_p": 1.0, } headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + if request_func_input.request_id: + 
headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -283,6 +296,8 @@ async def async_request_openai_completions( if request_func_input.extra_body: payload.update(request_func_input.extra_body) headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -395,6 +410,8 @@ async def async_request_openai_chat_completions( "Content-Type": "application/json", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -491,6 +508,8 @@ async def async_request_openai_audio( headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id # Send audio file def to_bytes(y, sr): diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index ea684f18a7421..2ea4f9ccaff2b 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -19,6 +19,7 @@ import logging import random from abc import ABC, abstractmethod from collections.abc import Mapping +from copy import deepcopy from dataclasses import dataclass from functools import cache from io import BytesIO @@ -54,6 +55,7 @@ class SampleRequest: expected_output_len: int multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None lora_request: Optional[LoRARequest] = None + request_id: Optional[str] = None # ----------------------------------------------------------------------------- @@ -155,7 +157,10 @@ class BenchmarkDataset(ABC): @abstractmethod def sample( - self, tokenizer: PreTrainedTokenizerBase, 
num_requests: int + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", ) -> list[SampleRequest]: """ Abstract method to generate sample requests from the dataset. @@ -167,6 +172,7 @@ class BenchmarkDataset(ABC): tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for processing the dataset's text. num_requests (int): The number of sample requests to generate. + request_id_prefix (str) The prefix of request_id. Returns: list[SampleRequest]: A list of sample requests generated from the @@ -175,7 +181,10 @@ class BenchmarkDataset(ABC): raise NotImplementedError("sample must be implemented in subclasses.") def maybe_oversample_requests( - self, requests: list[SampleRequest], num_requests: int + self, + requests: list[SampleRequest], + num_requests: int, + request_id_prefix: str = "", ) -> None: """ Oversamples the list of requests if its size is less than the desired @@ -183,11 +192,18 @@ class BenchmarkDataset(ABC): Args: requests (List[SampleRequest]): The current list of sampled - requests. num_requests (int): The target number of requests. + requests. + num_requests (int): The target number of requests. + request_id_prefix (str) The prefix of the request ids. """ if len(requests) < num_requests: random.seed(self.random_seed) - additional = random.choices(requests, k=num_requests - len(requests)) + additional = deepcopy( + random.choices(requests, k=num_requests - len(requests)) + ) + for i in range(len(additional)): + req = additional[i] + req.request_id = request_id_prefix + str(len(requests) + i) requests.extend(additional) logger.info("Oversampled requests to reach %d total samples.", num_requests) @@ -277,6 +293,41 @@ def process_image(image: Any) -> Mapping[str, Any]: ) +def process_video(video: Any) -> Mapping[str, Any]: + """ + Process a single video input and return a multimedia content dictionary. + + Supports the following input types: + + 1. 
Dictionary with raw video bytes: - Expects a dict with a 'bytes' key + containing raw video data. + + 2. String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the image URL. + + Raises: + ValueError: If the input is not a supported type. + """ + if isinstance(video, dict) and "bytes" in video: + video_bytes = video["bytes"] + video_base64 = base64.b64encode(video_bytes).decode("utf-8") + return { + "type": "video_url", + "video_url": {"url": f"data:video/mp4;base64,{video_base64}"}, + } + + if isinstance(video, str): + video_url = ( + video if video.startswith(("http://", "file://")) else f"file://{video}" + ) + return {"type": "video_url", "video_url": {"url": video_url}} + + raise ValueError( + f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501 + ) + + # ----------------------------------------------------------------------------- # Random Dataset Implementation (Synthetic Data) # ----------------------------------------------------------------------------- @@ -303,6 +354,7 @@ class RandomDataset(BenchmarkDataset): range_ratio: float = DEFAULT_RANGE_RATIO, input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, + request_id_prefix: str = "", **kwargs, ) -> list[SampleRequest]: # Enforce range_ratio < 1 @@ -363,8 +415,10 @@ class RandomDataset(BenchmarkDataset): prompt=prompt, prompt_len=total_input_len, expected_output_len=int(output_lens[i]), + request_id=request_id_prefix + str(i), ) ) + return requests @@ -406,9 +460,11 @@ class ShareGPTDataset(BenchmarkDataset): max_loras: Optional[int] = None, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: samples: list = [] + ind = 0 for entry in self.data: if len(samples) >= num_requests: break 
@@ -430,17 +486,26 @@ class ShareGPTDataset(BenchmarkDataset): skip_min_output_len_check=output_len is not None, ): continue + if image_path := entry.get("image"): + mm_content = process_image(image_path) + elif video_path := entry.get("video"): + mm_content = process_video(video_path) + else: + mm_content = None if enable_multimodal_chat: - prompt = self.apply_multimodal_chat_transformation(prompt, None) + prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) samples.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, expected_output_len=new_output_len, lora_request=lora_request, + multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), ) ) - self.maybe_oversample_requests(samples, num_requests) + ind += 1 + self.maybe_oversample_requests(samples, num_requests, request_id_prefix) return samples @@ -506,10 +571,11 @@ class CustomDataset(BenchmarkDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, skip_chat_template: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break prompt = item["prompt"] @@ -528,9 +594,12 @@ class CustomDataset(BenchmarkDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -572,6 +641,7 @@ class SonnetDataset(BenchmarkDataset): input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, return_prompt_formatted: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: # Calculate average token length for a poem line. 
@@ -597,6 +667,7 @@ class SonnetDataset(BenchmarkDataset): prefix_lines = self.data[:num_prefix_lines] samples = [] + ind = 0 while len(samples) < num_requests: extra_lines = random.choices( self.data, k=num_input_lines - num_prefix_lines @@ -607,14 +678,17 @@ class SonnetDataset(BenchmarkDataset): msg, add_generation_prompt=True, tokenize=False ) prompt_len = len(tokenizer(prompt_formatted).input_ids) + if prompt_len <= input_len: samples.append( SampleRequest( prompt=prompt_formatted if return_prompt_formatted else prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(ind), ) ) + ind += 1 return samples @@ -666,6 +740,7 @@ class BurstGPTDataset(BenchmarkDataset): num_requests: int, max_loras: Optional[int] = None, lora_path: Optional[str] = None, + request_id_prefix: str = "", **kwargs, ) -> list[SampleRequest]: samples = [] @@ -687,6 +762,7 @@ class BurstGPTDataset(BenchmarkDataset): prompt_len=input_len, expected_output_len=output_len, lora_request=lora_req, + request_id=request_id_prefix + str(i), ) ) return samples @@ -746,12 +822,14 @@ class ConversationDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: # Filter examples with at least 2 conversations filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) sampled_requests = [] dynamic_output = output_len is None + ind = 0 for item in filtered_data: if len(sampled_requests) >= num_requests: @@ -779,9 +857,13 @@ class ConversationDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + ind += 1 + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -808,11 +890,12 @@ class 
VisionArenaDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) @@ -832,9 +915,12 @@ class VisionArenaDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(i), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -864,15 +950,18 @@ class InstructCoderDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break - prompt = f"{item['input']}\n\n{item['instruction']} Just output \ - the code, do not include any explanation." + prompt = ( + f"{item['input']}\n\n{item['instruction']} Just output " + "the code, do not include any explanation." 
+ ) # apply template prompt = tokenizer.apply_chat_template( @@ -886,9 +975,12 @@ class InstructCoderDataset(HuggingFaceDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -918,12 +1010,13 @@ class MTBenchDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break prompt = item["turns"][0] @@ -941,9 +1034,12 @@ class MTBenchDataset(HuggingFaceDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -968,10 +1064,12 @@ class AIMODataset(HuggingFaceDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, output_len: Optional[int] = None, + request_id_prefix: str = "", **kwargs, ) -> list: sampled_requests = [] dynamic_output = output_len is None + ind = 0 for item in self.data: if len(sampled_requests) >= num_requests: @@ -994,9 +1092,13 @@ class AIMODataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=None, + request_id=request_id_prefix + str(ind), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + ind += 1 + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -1066,12 +1168,18 @@ class 
NextEditPredictionDataset(HuggingFaceDataset): "zed-industries/zeta": _format_zeta_prompt, } - def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs): + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", + **kwargs, + ): formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path) if formatting_prompt_func is None: raise ValueError(f"Unsupported dataset path: {self.dataset_path}") samples = [] - for sample in self.data: + for i, sample in enumerate(self.data): sample = formatting_prompt_func(sample) samples.append( SampleRequest( @@ -1080,11 +1188,12 @@ class NextEditPredictionDataset(HuggingFaceDataset): expected_output_len=len( tokenizer(sample["expected_output"]).input_ids ), + request_id=request_id_prefix + str(i), ) ) if len(samples) >= num_requests: break - self.maybe_oversample_requests(samples, num_requests) + self.maybe_oversample_requests(samples, num_requests, request_id_prefix) return samples @@ -1133,6 +1242,7 @@ class ASRDataset(HuggingFaceDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, output_len: Optional[int] = None, + request_id_prefix: str = "", **kwargs, ) -> list: import librosa @@ -1142,6 +1252,7 @@ class ASRDataset(HuggingFaceDataset): prompt_len = len(tokenizer(prompt).input_ids) sampled_requests = [] skipped = 0 + ind = 0 for item in self.data: if len(sampled_requests) >= num_requests: break @@ -1160,8 +1271,10 @@ class ASRDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), ) ) + ind += 1 if skipped: logger.warning( "%d samples discarded from dataset due to" @@ -1169,5 +1282,7 @@ class ASRDataset(HuggingFaceDataset): " what Whisper supports.", skipped, ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return 
sampled_requests diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index ae38caf7290b1..02f5f585c0c16 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -375,11 +375,12 @@ async def benchmark( rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) last_int_rps = current_int_rps - prompt, prompt_len, output_len, mm_content = ( + prompt, prompt_len, output_len, mm_content, request_id = ( request.prompt, request.prompt_len, request.expected_output_len, request.multi_modal_data, + request.request_id, ) req_model_id, req_model_name = model_id, model_name if lora_modules: @@ -397,6 +398,7 @@ async def benchmark( multi_modal_content=mm_content, ignore_eos=ignore_eos, extra_body=extra_body, + request_id=request_id, ) task = limited_request_func(request_func_input=request_func_input, pbar=pbar) tasks.append(asyncio.create_task(task)) @@ -665,6 +667,7 @@ def main(args: argparse.Namespace): tokenizer=tokenizer, output_len=args.custom_output_len, skip_chat_template=args.custom_skip_chat_template, + request_id_prefix=args.request_id_prefix, ) elif args.dataset_name == "sonnet": @@ -678,6 +681,7 @@ def main(args: argparse.Namespace): prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, return_prompt_formatted=False, + request_id_prefix=args.request_id_prefix, ) else: assert tokenizer.chat_template or tokenizer.default_chat_template, ( @@ -690,6 +694,7 @@ def main(args: argparse.Namespace): prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, return_prompt_formatted=True, + request_id_prefix=args.request_id_prefix, ) elif args.dataset_name == "hf": @@ -751,6 +756,7 @@ def main(args: argparse.Namespace): num_requests=args.num_prompts, tokenizer=tokenizer, output_len=args.hf_output_len, + request_id_prefix=args.request_id_prefix, ) else: @@ -762,10 +768,15 @@ def main(args: argparse.Namespace): tokenizer=tokenizer, num_requests=args.num_prompts, output_len=args.sharegpt_output_len, + 
request_id_prefix=args.request_id_prefix, ), "burstgpt": lambda: BurstGPTDataset( random_seed=args.seed, dataset_path=args.dataset_path - ).sample(tokenizer=tokenizer, num_requests=args.num_prompts), + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + request_id_prefix=args.request_id_prefix, + ), "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample( tokenizer=tokenizer, num_requests=args.num_prompts, @@ -773,6 +784,7 @@ def main(args: argparse.Namespace): input_len=args.random_input_len, output_len=args.random_output_len, range_ratio=args.random_range_ratio, + request_id_prefix=args.request_id_prefix, ), } @@ -1118,6 +1130,13 @@ def create_argument_parser(): "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "and the blog: https://hao-ai-lab.github.io/blogs/distserve", ) + parser.add_argument( + "--request-id-prefix", + type=str, + required=False, + default="benchmark-serving", + help="Specify the prefix of request id.", + ) # group for dataset specific arguments custom_group = parser.add_argument_group("custom dataset options") diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index c51b579686529..6b24b8c8f3c67 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -96,7 +96,6 @@ def run_vllm( end = time.perf_counter() else: assert lora_requests is None, "BeamSearch API does not support LoRA" - prompts = [request.prompt for request in requests] # output_len should be the same for all requests. 
output_len = requests[0].expected_output_len for request in requests: @@ -597,8 +596,8 @@ def validate_args(args): # https://github.com/vllm-project/vllm/issues/16222 if args.data_parallel_size > 1: raise ValueError( - "Data parallel is not supported in offline benchmark, \ - please use benchmark serving instead" + "Data parallel is not supported in offline benchmark, " + "please use benchmark serving instead" ) diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py deleted file mode 100644 index 42de062b08e42..0000000000000 --- a/benchmarks/kernels/benchmark_aqlm.py +++ /dev/null @@ -1,345 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -import sys -from typing import Optional - -import torch -import torch.nn.functional as F - -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.aqlm import ( - dequantize_weight, - generic_dequantize_gemm, - get_int_dtype, - optimized_dequantize_gemm, -) -from vllm.utils import FlexibleArgumentParser - -os.environ["CUDA_VISIBLE_DEVICES"] = "0" - - -def torch_mult( - # [..., in_features] - input: torch.Tensor, - weights: torch.Tensor, - # [num_out_groups, 1, 1, 1] - scales: torch.Tensor, -) -> torch.Tensor: - output = F.linear(input, weights) - return output - - -def dequant_out_scale( - # [..., in_features] - input: torch.Tensor, - # [num_out_groups, num_in_groups, num_codebooks] - codes: torch.IntTensor, - # [num_codebooks, codebook_size, out_group_size, in_group_size] - codebooks: torch.Tensor, - # [num_out_groups, 1, 1, 1] - scales: torch.Tensor, - output_partition_sizes: torch.IntTensor, - bias: Optional[torch.Tensor], -) -> torch.Tensor: - weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) - - if bias is None: - output = F.linear(input, weights, bias) - orig_shape = output.shape - flattened_output = output.view(-1, output.size(-1)) - f_scales = scales.view(-1, 
scales.shape[0]) - b_scales = f_scales.expand(flattened_output.shape[0], -1) - flattened_output *= b_scales - return flattened_output.view(orig_shape) - else: - b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1]) - weights *= b_scales - return F.linear(input, weights, bias) - - -def dequant_weight_scale( - # [..., in_features] - input: torch.Tensor, - # [num_out_groups, num_in_groups, num_codebooks] - codes: torch.IntTensor, - # [num_codebooks, codebook_size, out_group_size, in_group_size] - codebooks: torch.Tensor, - # [num_out_groups, 1, 1, 1] - scales: torch.Tensor, - output_partition_sizes: torch.IntTensor, - bias: Optional[torch.Tensor], -) -> torch.Tensor: - weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) - - b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1]) - weights *= b_scales - return F.linear(input, weights, bias) - - -def dequant_no_scale( - # [..., in_features] - input: torch.Tensor, - # [num_out_groups, num_in_groups, num_codebooks] - codes: torch.IntTensor, - # [num_codebooks, codebook_size, out_group_size, in_group_size] - codebooks: torch.Tensor, - # [num_out_groups, 1, 1, 1] - scales: torch.Tensor, - output_partition_sizes: torch.IntTensor, - bias: Optional[torch.Tensor], -) -> torch.Tensor: - weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) - - return F.linear(input, weights, bias) - - -# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against -# the generic pytorch version. -# Just visual comparison. 
-def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None: - n = int(parts.sum().item()) - - device = torch.device("cuda:0") - - code_range = (1 << bits) // 2 - ingroups = 8 - - codes = torch.randint( - -code_range, - code_range, - size=(n, k // ingroups, nbooks), - dtype=get_int_dtype(bits), - device=device, - ) - - codebooks = torch.randn( - size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), - dtype=torch.float16, - device=device, - ) - - count = 0 - for index in range(16): - for i in range(8): - for book in range(nbooks): - codebooks[book, index, 0, i] = count * (10**book) - count += 1 - - print("codes shape", codes.shape) - - for i in range(16): - for book in range(nbooks): - codes[0, i, book] = i - codes[0, -i, book] = i - - weights = dequantize_weight(codes, codebooks, None) - weights2 = ops.aqlm_dequant(codes, codebooks, parts) - - print("weights shape:", weights.shape) - print("weights2 shape:", weights2.shape) - - print("weights are:", weights) - print("weights2 are:", weights2) - - print("first 128 weights are", weights[0, 0:128].to(torch.int32)) - print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32)) - - print("last 128 weights are", weights[0, -128:]) - print("last 128 weights2 are:", weights2[0, -128:]) - - -def main(): - parser = FlexibleArgumentParser(description="Benchmark aqlm performance.") - - # Add arguments - parser.add_argument( - "--nbooks", type=int, default=1, help="Number of codebooks (default: 1)" - ) - parser.add_argument( - "--bits", - type=int, - default=16, - help="Number of bits per code element (default: 16)", - ) - parser.add_argument( - "--test", - type=bool, - default=False, - help="Run the decompression/dequant tester rather than benchmarking " - "(default: False)", - ) - - # Parse the arguments - args = parser.parse_args() - - # Extract values - nbooks = args.nbooks - bits = args.bits - - if args.test: - dequant_test(4096, torch.tensor((4096,)), nbooks, bits) - return - - # Otherwise, 
benchmark. - methods = [ - ops.aqlm_gemm, - dequant_out_scale, - generic_dequantize_gemm, - optimized_dequantize_gemm, - dequant_weight_scale, - torch_mult, - dequant_no_scale, - ] - - filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv" - print(f"writing benchmarks to file {filename}") - with open(filename, "w") as f: - sys.stdout = f - - print("m | k | n | n parts", end="") - for method in methods: - print(f" | {method.__name__.replace('_', ' ')} (µs)", end="") - print("") - - # These are reasonable prefill sizes. - ksandpartions = ( - (4096, (4096, 4096, 4096)), - (4096, (4096,)), - (4096, (11008, 11008)), - (11008, (4096,)), - ) - - # reasonable ranges for m. - for m in [ - 1, - 2, - 4, - 8, - 10, - 12, - 14, - 16, - 24, - 32, - 48, - 52, - 56, - 64, - 96, - 112, - 128, - 256, - 512, - 1024, - 1536, - 2048, - 3072, - 4096, - ]: - print(f"{m}", file=sys.__stdout__) - for ksp in ksandpartions: - run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, methods) - - sys.stdout = sys.__stdout__ - - -def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, methods): - # I didn't see visible improvements from increasing these, but feel free :) - num_warmup_trials = 1 - num_trials = 1 - - num_calls = 100 - - # warmup. 
- for method in methods: - for _ in range(num_warmup_trials): - run_timing( - num_calls=num_calls, - m=m, - k=k, - parts=parts, - nbooks=nbooks, - bits=bits, - method=method, - ) - - n = parts.sum().item() - print(f"{m} | {k} | {n} | {parts.tolist()}", end="") - - for method in methods: - best_time_us = 1e20 - for _ in range(num_trials): - kernel_dur_ms = run_timing( - num_calls=num_calls, - m=m, - k=k, - parts=parts, - nbooks=nbooks, - bits=bits, - method=method, - ) - - kernel_dur_us = 1000 * kernel_dur_ms - - if kernel_dur_us < best_time_us: - best_time_us = kernel_dur_us - - print(f" | {kernel_dur_us:.0f}", end="") - - print("") - - -def run_timing( - num_calls: int, m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method -) -> float: - n = int(parts.sum().item()) - - device = torch.device("cuda:0") - - input = torch.randn((1, m, k), dtype=torch.float16, device=device) - - code_range = (1 << bits) // 2 - ingroups = 8 - - codes = torch.randint( - -code_range, - code_range, - size=(n, k // ingroups, nbooks), - dtype=get_int_dtype(bits), - device=device, - ) - - codebooks = torch.randn( - size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), - dtype=torch.float16, - device=device, - ) - - scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device) - - # for comparison to just a pytorch mult. 
- weights = torch.randn((n, k), dtype=torch.float16, device=device) - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - - if method is torch_mult: - for i in range(num_calls): - torch_mult(input, weights, scales) - else: - for i in range(num_calls): - method(input, codes, codebooks, scales, parts, None) - - end_event.record() - end_event.synchronize() - - dur_ms = start_event.elapsed_time(end_event) / num_calls - return dur_ms - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 1d4e730f99ae9..a6b42406b5cb0 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -80,6 +80,11 @@ def bench_run( a, score, topk, renormalize=False ) + ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) + def run_triton_moe( a: torch.Tensor, w1: torch.Tensor, @@ -111,6 +116,10 @@ def bench_run( w2: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, per_act_token: bool, @@ -125,6 +134,10 @@ def bench_run( topk_ids, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, per_act_token, a1_scale=None, ) @@ -136,6 +149,10 @@ def bench_run( w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: 
torch.Tensor, ): @@ -150,6 +167,10 @@ def bench_run( topk_ids, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, per_act_token, a1_scale=None, ) @@ -194,6 +215,10 @@ def bench_run( w2_q, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, topk_weights, topk_ids, ) @@ -231,6 +256,10 @@ def bench_run( "w1_scale": w1_scale, "w2_scale": w2_scale, "per_act_token": per_act_token, + "ab_strides1": ab_strides1, + "ab_strides2": ab_strides2, + "c_strides1": c_strides1, + "c_strides2": c_strides2, # cuda graph params "cutlass_graph": cutlass_graph, "triton_graph": triton_graph, @@ -289,6 +318,10 @@ def bench_run( w2_q, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, topk_weights, topk_ids, per_act_token, @@ -297,7 +330,7 @@ def bench_run( results.append( benchmark.Timer( - stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501 + stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 975d10f2e92ec..1b1c3b321cce4 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -253,28 +253,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: else: assert bt.a.dtype == torch.int8 assert bt.wtype == scalar_types.uint4b8 - - if bt.w_ch_s is not None: - s_ch = bt.w_ch_s.to(torch.float32) - else: - s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device) - - if bt.w_tok_s is not None: - s_tok = bt.w_tok_s.to(torch.float32) - else: - s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device) - - fn = lambda: ops.marlin_qqq_gemm( - a=bt.a, - b_q_weight=w_q, - s_group=w_s, - 
s_tok=s_tok, - s_ch=s_ch, - workspace=workspace.scratch, - size_m=bt.a.shape[0], - size_n=bt.w_ref.shape[1], - size_k=bt.w_ref.shape[0], - ) + raise NotImplementedError("QQQ is not supported anymore") return fn @@ -305,6 +284,25 @@ def machete_create_bench_fn( ) +def cutlass_w4a8_create_bench_fn( + bt: BenchmarkTensors, out_type=torch.dtype, schedule=None +) -> Callable: + w_q = bt.w_q.t().contiguous().t() # make col major + w_q = ops.cutlass_encode_and_reorder_int4b(w_q) + # expects fp8 scales + w_s = ops.cutlass_pack_scale_fp8(bt.w_g_s.to(torch.float8_e4m3fn)) + + return lambda: ops.cutlass_w4a8_mm( + a=bt.a, + b_q=w_q, + b_group_scales=w_s, + b_group_size=bt.group_size, + b_channel_scales=bt.w_ch_s, + a_token_scales=bt.w_tok_s, + maybe_schedule=schedule, + ) + + # impl # bench @@ -406,6 +404,20 @@ def bench( ) ) + # cutlass w4a8 + if types.act_type == torch.float8_e4m3fn and group_size == 128: + timers.append( + bench_fns( + label, + sub_label, + f"cutlass w4a8 ({name_type_string})", + [ + cutlass_w4a8_create_bench_fn(bt, out_type=types.output_type) + for bt in benchmark_tensors + ], + ) + ) + if sweep_schedules: global _SWEEP_SCHEDULES_RESULTS diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 13bf1be836f6a..752c2d0082167 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -3,6 +3,7 @@ import argparse import json +import os import time from contextlib import nullcontext from datetime import datetime @@ -429,7 +430,6 @@ class BenchmarkWorker: hidden_size, topk, dtype_str, - is_marlin=False, ) else: config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] @@ -542,6 +542,7 @@ def save_configs( use_fp8_w8a8: bool, use_int8_w8a16: bool, block_quant_shape: list[int], + save_dir: str, ) -> None: dtype_str = get_config_dtype_str( dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 @@ -552,7 +553,8 @@ def save_configs( filename = get_config_file_name( 
num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape ) - + os.makedirs(save_dir, exist_ok=True) + filename = os.path.join(save_dir, filename) print(f"Writing best config to {filename}...") with open(filename, "w") as f: json.dump(configs, f, indent=4) @@ -707,6 +709,7 @@ def main(args: argparse.Namespace): use_fp8_w8a8, use_int8_w8a16, block_quant_shape, + args.save_dir, ) end = time.time() print(f"Tuning took {end - start:.2f} seconds") @@ -748,6 +751,9 @@ if __name__ == "__main__": "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto" ) parser.add_argument("--use-deep-gemm", action="store_true") + parser.add_argument( + "--save-dir", type=str, default="./", help="Directory to save tuned results" + ) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--batch-size", type=int, nargs="+", required=False) parser.add_argument("--tune", action="store_true") diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py new file mode 100644 index 0000000000000..0650cbf3cc18e --- /dev/null +++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time + +import torch + +from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + silu_mul_fp8_quant_deep_gemm, +) +from vllm.platforms import current_platform + + +def benchmark(E, T, H, G=128, runs=50): + current_platform.seed_everything(42) + y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") + tokens_per_expert = torch.randint( + T // 2, T, size=(E,), dtype=torch.int32, device="cuda" + ) + + # Warmup + for _ in range(10): + silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) + torch.cuda.synchronize() + + # Benchmark + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(runs): + 
silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) + torch.cuda.synchronize() + + avg_time = (time.perf_counter() - start) / runs * 1000 + + # Calculate actual work done (only count valid tokens) + actual_tokens = tokens_per_expert.sum().item() + actual_elements = actual_tokens * H + + # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops + ops_per_element = 8 + total_ops = actual_elements * ops_per_element + gflops = total_ops / (avg_time / 1000) / 1e9 + + # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes) + input_bytes = actual_tokens * 2 * H * 2 # 2*H bfloat16 inputs + output_bytes = actual_tokens * H * 1 # H fp8 outputs + scale_bytes = actual_tokens * (H // G) * 4 # scales in float32 + total_bytes = input_bytes + output_bytes + scale_bytes + memory_bw = total_bytes / (avg_time / 1000) / 1e9 + + return avg_time, gflops, memory_bw + + +configs = [ + (8, 32, 1024), + (16, 64, 2048), + (32, 128, 4096), + # DeepSeekV3 Configs + (256, 16, 7168), + (256, 32, 7168), + (256, 64, 7168), + (256, 128, 7168), + (256, 256, 7168), + (256, 512, 7168), + (256, 1024, 7168), +] + +print(f"GPU: {torch.cuda.get_device_name()}") +print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}") +print("-" * 50) + +for E, T, H in configs: + try: + time_ms, gflops, gbps = benchmark(E, T, H) + print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}") + except Exception: + print(f"E={E:3d},T={T:4d},H={H:4d} FAILED") diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py index 77136edca45b5..603ce5ecf0d2c 100644 --- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py @@ -3,16 +3,17 @@ import csv import os -import random from datetime import datetime +from typing import Optional import flashinfer import torch -FLOAT32_BYTES = 
torch.finfo(torch.float).bits // 8 +from vllm.utils import round_up -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +FP8_DTYPE = torch.float8_e4m3fn +FP4_DTYPE = torch.uint8 def to_float8(x, dtype=torch.float8_e4m3fn): @@ -26,65 +27,106 @@ def to_float8(x, dtype=torch.float8_e4m3fn): @torch.no_grad() def benchmark_decode( - num_seqs, - max_seq_len, - page_size=16, - dtype=torch.bfloat16, - kv_layout="HND", - num_kv_heads=8, - kv_cache_dtype="auto", - head_dim=128, - warmup=10, - trials=20, + dtype: torch.dtype, + quant_dtypes: tuple[ + Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] + ], + batch_size: int, + max_seq_len: int, + num_heads: tuple[int, int] = (64, 8), + head_size: int = 128, + kv_layout: str = "HND", + block_size: int = 16, + warmup: int = 10, + trials: int = 20, ): torch.set_default_device("cuda") - device = "cuda" torch.manual_seed(0) - HEAD_GRP_SIZE = 8 - MAX_SEQ_LEN = max_seq_len + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 + + sm_scale = float(1.0 / (head_size**0.5)) # large number to reduce kv_cache reuse - NUM_BLOCKS = int(256000 / page_size) + NUM_BLOCKS = int(256000 / block_size) - workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device) + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") - # For decode, batch_size is num_decode_token - num_qo_heads = num_kv_heads * HEAD_GRP_SIZE - sm_scale = float(1.0 / (head_dim**0.5)) - q = 
torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype) - kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + # Always using 1.0 scale to reflect the real perf in benchmarking + q_scale = 1.0 + ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, _ = to_float8(ref_query) + else: + query = ref_query - max_kv_len = max(kv_lens) - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device) - max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size + kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32) + kv_lens[-1] = max_seq_len - block_tables = torch.randint( - 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 - ) + seq_lens = kv_lens + max_seq_len = torch.max(seq_lens).item() - kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) - kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype) + # Always using 1.0 scale to reflect the real perf in benchmarking k_scale = v_scale = 1.0 + ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, _ = to_float8(ref_kv_cache) + else: + kv_cache = ref_kv_cache - if kv_cache_dtype.startswith("fp8"): - kv_cache, _ = to_float8(kv_cache) + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables = torch.randint( + 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 + ) + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(batch_size): + seq_len = seq_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) - output_trtllm = torch.empty(q.shape, dtype=dtype) + 
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8) - # Benchmark TRT decode - def trt_decode(): - return flashinfer.decode.trtllm_batch_decode_with_kv_cache( - q, - kv_cache, - workspace_buffer, - block_tables, - kv_lens_tensor, - max_kv_len, - bmm1_scale=k_scale * sm_scale, - bmm2_scale=v_scale, - out=output_trtllm, - ) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout, + use_tensor_cores=True, + ) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_lens, + num_qo_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + sm_scale=sm_scale, + q_data_type=dtype, + kv_data_type=dtype, + ) def time_fn(fn, warmup=10, trials=20): torch.cuda.synchronize() @@ -101,74 +143,72 @@ def benchmark_decode( times.append(start.elapsed_time(end)) # ms return sum(times) / len(times), torch.std(torch.tensor(times)) - # TRT Decode - trt_mean, trt_std = time_fn(trt_decode) - - kv_indptr = [0] - kv_indices = [] - kv_last_page_lens = [] - for i in range(num_seqs): - seq_len = kv_lens[i] - assert seq_len > 0 - num_blocks = (seq_len + page_size - 1) // page_size - kv_indices.extend(block_tables[i, :num_blocks]) - kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % page_size - if kv_last_page_len == 0: - kv_last_page_len = page_size - kv_last_page_lens.append(kv_last_page_len) - - kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) - kv_indices = torch.tensor(kv_indices, dtype=torch.int32) - kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - - output_baseline = torch.empty(q.shape, dtype=dtype) - - wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout, - use_tensor_cores=((num_qo_heads // num_kv_heads) > 4), - ) - - wrapper.plan( - kv_indptr, - 
kv_indices, - kv_last_page_lens, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - "NONE", - q_data_type=dtype, - kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype, - ) + o_scale = 1.0 + o_sf_scale = None + output_baseline = torch.empty(ref_query.shape, dtype=dtype) + if o_quant_dtype == FP4_DTYPE: + o_sf_scale = 500.0 + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), + torch.empty( + ( + round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4), + ), + dtype=torch.float8_e4m3fn, + ), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) def baseline_decode(): - return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline) + return wrapper.run( + ref_query, + ref_kv_cache, + k_scale=k_scale, + v_scale=v_scale, + out=output_baseline, + ) + + def trtllm_decode(): + return flashinfer.decode.trtllm_batch_decode_with_kv_cache( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + o_sf_scale=o_sf_scale, + out=output_trtllm, + ) baseline_mean, baseline_std = time_fn(baseline_decode) + trtllm_mean, trtllm_std = time_fn(trtllm_decode) # Calculate percentage speedup (positive means TRT is faster) - speedup_percent = (baseline_mean - trt_mean) / baseline_mean + speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean print( - f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}" + f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}" f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}" ) # Return results for CSV writing return { - "num_seqs": num_seqs, - "trt_mean": trt_mean, - "trt_std": trt_std.item(), + "batch_size": batch_size, + "trtllm_mean": trtllm_mean, + 
"trtllm_std": trtllm_std.item(), "baseline_mean": baseline_mean, "baseline_std": baseline_std.item(), "speedup_percent": speedup_percent, - "q_dtype": str(dtype), - "kv_cache_dtype": kv_cache_dtype, - "page_size": page_size, + "q_dtype": str(q_quant_dtype), + "kv_cache_dtype": str(kv_quant_dtype), + "output_dtype": str(o_quant_dtype), + "block_size": block_size, "num_kv_heads": num_kv_heads, - "head_dim": head_dim, + "head_size": head_size, "max_seq_len": max_seq_len, } @@ -180,17 +220,18 @@ def write_results_to_csv(results, filename=None): filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" fieldnames = [ - "num_seqs", - "trt_mean", - "trt_std", + "batch_size", + "trtllm_mean", + "trtllm_std", "baseline_mean", "baseline_std", "speedup_percent", "q_dtype", "kv_cache_dtype", - "page_size", + "output_dtype", + "block_size", "num_kv_heads", - "head_dim", + "head_size", "max_seq_len", ] @@ -209,45 +250,43 @@ def write_results_to_csv(results, filename=None): if __name__ == "__main__": - num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] + batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256] max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] all_results = [] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " - "output_dtype: bfloat16" - ) - print( - "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - "baseline_std\tspeedup_percent" - ) - for max_seq_len in max_seq_lens: - for bs in num_seqs: - result = benchmark_decode( - bs, - max_seq_len, - dtype=torch.bfloat16, - kv_cache_dtype="auto", - ) - all_results.append(result) + dtype = torch.bfloat16 + quant_dtypes = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (None, FP8_DTYPE, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), + ] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, " - "output_dtype: bfloat16" - ) - print( - "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - 
"baseline_std\tspeedup_percent" - ) - for max_seq_len in max_seq_lens: - for bs in num_seqs: - result = benchmark_decode( - bs, - max_seq_len, - dtype=torch.bfloat16, - kv_cache_dtype="fp8", - ) - all_results.append(result) + for quant_dtype in quant_dtypes: + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + print( + f"Running benchmark for q_dtype = {q_quant_dtype}, " + f"kv_cache_dtype: {kv_quant_dtype}, " + f"output_dtype: {o_quant_dtype}" + ) + print( + "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t" + "baseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in batch_sizes: + result = benchmark_decode( + dtype=dtype, + quant_dtypes=quant_dtype, + batch_size=bs, + max_seq_len=max_seq_len, + ) + all_results.append(result) # Write all results to CSV write_results_to_csv(all_results) diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py index 67bd9aebbcca9..40903c6c3444f 100644 --- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py @@ -3,16 +3,17 @@ import csv import os -import random from datetime import datetime +from typing import Optional import flashinfer import torch -FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +from vllm.utils import round_up -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +FP8_DTYPE = torch.float8_e4m3fn +FP4_DTYPE = torch.uint8 def to_float8(x, dtype=torch.float8_e4m3fn): @@ -26,84 +27,100 @@ def to_float8(x, dtype=torch.float8_e4m3fn): @torch.no_grad() def benchmark_prefill( - num_seqs, - max_seq_len, - page_size=16, - dtype=torch.bfloat16, - kv_layout="HND", - num_kv_heads=8, - kv_cache_dtype="auto", - 
head_dim=128, - warmup=10, - trials=20, + dtype: torch.dtype, + quant_dtypes: tuple[ + Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] + ], + batch_size: int, + max_seq_len: int, + num_heads: tuple[int, int] = (64, 8), + head_size: int = 128, + kv_layout: str = "HND", + block_size: int = 16, + warmup: int = 10, + trials: int = 20, ): torch.set_default_device("cuda") torch.manual_seed(0) - HEAD_GRP_SIZE = 8 - MAX_SEQ_LEN = max_seq_len + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + max_q_len = max_kv_len = max_seq_len + + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 + + sm_scale = float(1.0 / (head_size**0.5)) # large number to reduce kv_cache reuse - NUM_BLOCKS = int(256000 / page_size) + NUM_BLOCKS = int(256000 / block_size) - workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8) + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") - num_qo_heads = num_kv_heads * HEAD_GRP_SIZE - sm_scale = float(1.0 / (head_dim**0.5)) - - q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - q_lens[-1] = MAX_SEQ_LEN - max_q_len = max(q_lens) + q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32) + q_lens[-1] = max_q_len q_indptr = torch.cat( [ torch.tensor([0], dtype=torch.int32), - torch.cumsum( - torch.tensor(q_lens, dtype=torch.int32), dim=0, dtype=torch.int32 - ), + torch.cumsum(q_lens, dim=0, dtype=torch.int32), ] ) - q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype) - kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)] - kv_lens[-1] = MAX_SEQ_LEN - - seq_lens = [q_len + kv_len 
for q_len, kv_len in zip(q_lens, kv_lens)] - max_seq_len = max(seq_lens) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32) - - max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size - block_tables = torch.randint( - 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + # Always using 1.0 scale to reflect the real perf in benchmarking + q_scale = 1.0 + ref_query = torch.randn( + torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype ) + if q_quant_dtype == FP8_DTYPE: + query, _ = to_float8(ref_query) + else: + query = ref_query - kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) - kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype) + kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32) + kv_lens[-1] = max_kv_len + + seq_lens = kv_lens + q_lens + max_seq_len = torch.max(seq_lens).item() + + # Always using 1.0 scale to reflect the real perf in benchmarking k_scale = v_scale = 1.0 + ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, _ = to_float8(ref_kv_cache) + else: + kv_cache = ref_kv_cache - if kv_cache_dtype.startswith("fp8"): - kv_cache, _ = to_float8(kv_cache) - - output_trtllm = torch.empty(q.shape, dtype=dtype) - + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables = torch.randint( + 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 + ) kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] - for i in range(num_seqs): + for i in range(batch_size): seq_len = seq_lens[i] assert seq_len > 0 - num_blocks = (seq_len + page_size - 1) // page_size + num_blocks = (seq_len + block_size - 1) // block_size kv_indices.extend(block_tables[i, :num_blocks]) kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % page_size + kv_last_page_len = seq_len % block_size if kv_last_page_len == 0: - kv_last_page_len = page_size + kv_last_page_len = block_size 
kv_last_page_lens.append(kv_last_page_len) kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - - output_baseline = torch.empty(q.shape, dtype=dtype) + workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8) wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout @@ -115,12 +132,12 @@ def benchmark_prefill( kv_last_page_lens, num_qo_heads, num_kv_heads, - head_dim, - page_size, + head_size, + block_size, causal=True, sm_scale=sm_scale, q_data_type=dtype, - kv_data_type=kv_cache.dtype, + kv_data_type=dtype, ) def time_fn(fn, warmup=10, trials=20): @@ -138,52 +155,76 @@ def benchmark_prefill( times.append(start.elapsed_time(end)) # ms return sum(times) / len(times), torch.std(torch.tensor(times)) + o_scale = 1.0 + o_sf_scale = None + output_baseline = torch.empty(ref_query.shape, dtype=dtype) + if o_quant_dtype == FP4_DTYPE: + o_sf_scale = 500.0 + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), + torch.empty( + ( + round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4), + ), + dtype=torch.float8_e4m3fn, + ), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + def baseline_prefill(): return wrapper.run( - q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline + ref_query, + ref_kv_cache, + k_scale=k_scale, + v_scale=v_scale, + out=output_baseline, ) - def trt_prefill(): + def trtllm_prefill(): return flashinfer.prefill.trtllm_batch_context_with_kv_cache( - query=q, + query=query, kv_cache=kv_cache, workspace_buffer=workspace_buffer, block_tables=block_tables, - seq_lens=seq_lens_tensor, + seq_lens=seq_lens, max_q_len=max_q_len, max_kv_len=max_seq_len, - bmm1_scale=k_scale * sm_scale, - bmm2_scale=v_scale, - batch_size=num_seqs, + bmm1_scale=q_scale * k_scale 
* sm_scale, + bmm2_scale=v_scale / o_scale, + batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, + o_sf_scale=o_sf_scale, out=output_trtllm, ) - trt_mean, trt_std = time_fn(trt_prefill) baseline_mean, baseline_std = time_fn(baseline_prefill) + trtllm_mean, trtllm_std = time_fn(trtllm_prefill) # Calculate percentage speedup (positive means TRT is faster) - speedup_percent = (baseline_mean - trt_mean) / baseline_mean + speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean print( - f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}" - f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}" + f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}" + f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}" ) # Return results for CSV writing return { - "num_seqs": num_seqs, - "trt_mean": trt_mean, - "trt_std": trt_std.item(), + "batch_size": batch_size, + "trtllm_mean": trtllm_mean, + "trtllm_std": trtllm_std.item(), "baseline_mean": baseline_mean, "baseline_std": baseline_std.item(), "speedup_percent": speedup_percent, - "q_dtype": str(dtype), - "kv_cache_dtype": kv_cache_dtype, - "page_size": page_size, + "q_dtype": str(q_quant_dtype), + "kv_cache_dtype": str(kv_quant_dtype), + "output_dtype": str(o_quant_dtype), + "block_size": block_size, "num_kv_heads": num_kv_heads, - "head_dim": head_dim, + "head_size": head_size, "max_seq_len": max_seq_len, } @@ -195,17 +236,18 @@ def write_results_to_csv(results, filename=None): filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" fieldnames = [ - "num_seqs", - "trt_mean", - "trt_std", + "batch_size", + "trtllm_mean", + "trtllm_std", "baseline_mean", "baseline_std", "speedup_percent", "q_dtype", "kv_cache_dtype", - "page_size", + "output_dtype", + "block_size", "num_kv_heads", - "head_dim", + "head_size", "max_seq_len", ] @@ -224,27 +266,42 @@ def write_results_to_csv(results, filename=None): if __name__ == 
"__main__": - num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] + batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256] max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] all_results = [] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " - "output_dtype: bfloat16" - ) - print( - "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - "baseline_std\tspeedup_percent" - ) - for max_seq_len in max_seq_lens: - for bs in num_seqs: - result = benchmark_prefill( - bs, - max_seq_len, - dtype=torch.bfloat16, - kv_cache_dtype="auto", - ) - all_results.append(result) + dtype = torch.bfloat16 + quant_dtypes = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), + ] + + for quant_dtype in quant_dtypes: + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + print( + f"Running benchmark for q_dtype = {q_quant_dtype}, " + f"kv_cache_dtype: {kv_quant_dtype}, " + f"output_dtype: {o_quant_dtype}" + ) + print( + "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t" + "baseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in batch_sizes: + result = benchmark_prefill( + dtype=dtype, + quant_dtypes=quant_dtype, + batch_size=bs, + max_seq_len=max_seq_len, + ) + all_results.append(result) # Write all results to CSV write_results_to_csv(all_results) diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py index 4fcdbadd65ecd..e648a91077fdb 100644 --- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py +++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py @@ -11,8 +11,8 @@ from datetime import datetime from typing import Any import torch -import tqdm import triton +from tqdm import tqdm from 
vllm.model_executor.layers.quantization.utils.fp8_utils import ( _w8a8_block_fp8_matmul, diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index a27f02394afbd..9a057990bda5f 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -95,4 +95,10 @@ WEIGHT_SHAPES = { ([2048, 2816], 1), ([1408, 2048], 0), ], + "CohereLabs/c4ai-command-a-03-2025": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 73728], 1), + ([36864, 12288], 0), + ], } diff --git a/benchmarks/multi_turn/README.md b/benchmarks/multi_turn/README.md index ae0866ae60751..7adf97bcf5622 100644 --- a/benchmarks/multi_turn/README.md +++ b/benchmarks/multi_turn/README.md @@ -5,11 +5,13 @@ The requirements (pip) for `benchmark_serving_multi_turn.py` can be found in `re First start serving your model ```bash -export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ +export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ -vllm serve $MODEL_NAME --disable-log-requests +vllm serve $MODEL_PATH --served-model-name Llama --disable-log-requests ``` +The variable `MODEL_PATH` should be a path to the model files (e.g. downloaded from huggingface). 
+ ## Synthetic Multi-Turn Conversations Download the following text file (used for generation of synthetic conversations) @@ -26,10 +28,10 @@ But you may use other text files if you prefer (using this specific file is not Then run the benchmarking script ```bash -export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ +export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ -python benchmark_serving_multi_turn.py --model $MODEL_NAME --input-file generate_multi_turn.json \ ---num-clients 2 --max-active-conversations 6 +python benchmark_serving_multi_turn.py --model $MODEL_PATH --served-model-name Llama \ +--input-file generate_multi_turn.json --num-clients 2 --max-active-conversations 6 ``` You can edit the file `generate_multi_turn.json` to change the conversation parameters (number of turns, etc.). diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py index 53c3207491d18..d23b7b6e4571d 100644 --- a/benchmarks/multi_turn/benchmark_serving_multi_turn.py +++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py @@ -825,9 +825,11 @@ def get_client_config( # Arguments for API requests chat_url = f"{args.url}/v1/chat/completions" + model_name = args.served_model_name if args.served_model_name else args.model + req_args = RequestArgs( chat_url=chat_url, - model=args.model, + model=model_name, stream=not args.no_stream, limit_min_tokens=args.limit_min_tokens, limit_max_tokens=args.limit_max_tokens, @@ -1247,9 +1249,19 @@ async def main() -> None: default=0, help="Seed for random number generators (default: 0)", ) + parser.add_argument( "-m", "--model", type=str, required=True, help="Path of the LLM model" ) + parser.add_argument( + "--served-model-name", + type=str, + default=None, + help="The model name used in the API. " + "If not specified, the model name will be the " + "same as the ``--model`` argument. 
", + ) + parser.add_argument( "-u", "--url", diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index e0da46e2accaa..52bfd82c7fcfe 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -1,6 +1,7 @@ include(FetchContent) set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_EXTENSIONS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) @@ -182,17 +183,17 @@ endif() # # Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms) # Flag to enable ACL kernels for AARCH64 platforms -if ( VLLM_BUILD_ACL STREQUAL "ON") +if (VLLM_BUILD_ACL STREQUAL "ON") set(USE_ACL ON) else() set(USE_ACL OFF) endif() -if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) +if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) FetchContent_Declare( oneDNN GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git - GIT_TAG v3.8.1 + GIT_TAG v3.9 GIT_PROGRESS TRUE GIT_SHALLOW TRUE ) @@ -204,7 +205,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) endif() set(ONEDNN_AARCH64_USE_ACL "ON") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/") - endif() + endif() set(ONEDNN_LIBRARY_TYPE "STATIC") set(ONEDNN_BUILD_DOC "OFF") @@ -217,38 +218,23 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) set(ONEDNN_ENABLE_ITT_TASKS "OFF") set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") + set(ONEDNN_VERBOSE "OFF") set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) FetchContent_MakeAvailable(oneDNN) - - list(APPEND LIBS dnnl) -elseif(POWER10_FOUND) - FetchContent_Declare( - oneDNN - GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git - GIT_TAG v3.7.2 - GIT_PROGRESS TRUE - GIT_SHALLOW TRUE + add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp") + target_include_directories( + dnnl_ext + PUBLIC ${oneDNN_SOURCE_DIR}/include + PUBLIC ${oneDNN_BINARY_DIR}/include + PRIVATE ${oneDNN_SOURCE_DIR}/src ) - - set(ONEDNN_LIBRARY_TYPE 
"STATIC") - set(ONEDNN_BUILD_DOC "OFF") - set(ONEDNN_BUILD_EXAMPLES "OFF") - set(ONEDNN_BUILD_TESTS "OFF") - set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") - set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") - set(ONEDNN_BUILD_GRAPH "OFF") - set(ONEDNN_ENABLE_JIT_PROFILING "OFF") - set(ONEDNN_ENABLE_ITT_TASKS "OFF") - set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") - set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") - set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) - - set(DNNL_CPU_RUNTIME "OMP") - - FetchContent_MakeAvailable(oneDNN) - - list(APPEND LIBS dnnl) + target_link_libraries(dnnl_ext dnnl) + target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC) + list(APPEND LIBS dnnl_ext) + set(USE_ONEDNN ON) +else() + set(USE_ONEDNN OFF) endif() message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") @@ -275,7 +261,6 @@ set(VLLM_EXT_SRC if (AVX512_FOUND AND NOT AVX512_DISABLED) set(VLLM_EXT_SRC - "csrc/cpu/quant.cpp" "csrc/cpu/shm.cpp" ${VLLM_EXT_SRC}) if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI) @@ -289,14 +274,11 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) ${VLLM_EXT_SRC}) add_compile_definitions(-DCPU_CAPABILITY_AVX512) endif() -elseif(POWER10_FOUND) - set(VLLM_EXT_SRC - "csrc/cpu/quant.cpp" - ${VLLM_EXT_SRC}) endif() -if (ASIMD_FOUND) + +if(USE_ONEDNN) set(VLLM_EXT_SRC - "csrc/cpu/quant.cpp" + "csrc/cpu/dnnl_kernels.cpp" ${VLLM_EXT_SRC}) endif() diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index ee6768bce26ca..02224cfe3ee81 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -19,7 +19,7 @@ else() FetchContent_Declare( flashmla GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git - GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1 + GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" @@ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND 
FLASH_MLA_ARCHS) set(FlashMLA_SOURCES ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp - ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu - ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu) + ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu) set(FlashMLA_INCLUDES ${flashmla_SOURCE_DIR}/csrc/cutlass/include - ${flashmla_SOURCE_DIR}/csrc/include) + ${flashmla_SOURCE_DIR}/csrc) set_gencode_flags_for_srcs( SRCS "${FlashMLA_SOURCES}" diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index d24d8e8e5e795..49defccbb1fa4 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 93cf5a08f421a3efd0c4a7e005ef8f742b578ce0 + GIT_TAG 57b4e68b9f9d94750b46de8f8dbd2bfcc86edd4f GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 55e6596797010..a4a880f13cf7e 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -128,6 +128,45 @@ __global__ void act_and_mul_kernel_with_param( } } +template +__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up, + float alpha, float limit) { + // clamp gate: min=None, max=limit + const float gate_f = (float)gate; + const float clamped_gate = gate_f > limit ? limit : gate_f; + + // clamp up: min=-limit, max=limit + const float up_f = (float)up; + const float clamped_up = + up_f > limit ? limit : (up_f < -limit ? 
-limit : up_f); + + // glu = gate * sigmoid(gate * alpha) + const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha)); + const float glu = clamped_gate * sigmoid_val; + + // (up + 1) * glu + return (T)((clamped_up + 1.0f) * glu); +} + +template +__global__ void swigluoai_and_mul_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const int d, const float alpha, const float limit) { + const int64_t token_idx = blockIdx.x; + // TODO: Vectorize loads and stores. + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + // gate = x[..., ::2] (even indices) + const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]); + // up = x[..., 1::2] (odd indices) + const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]); + + out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit); + } +} + } // namespace vllm #define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \ @@ -145,11 +184,31 @@ __global__ void act_and_mul_kernel_with_param( PARAM); \ }); +#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \ + vllm::swigluoai_and_mul_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d, ALPHA, \ + LIMIT); \ + }); + void fatrelu_and_mul(torch::Tensor& out, // [..., d], torch::Tensor& input, // [..., 2 * d] double threshold) { LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold); } +void swigluoai_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., 2 * d] + double alpha, double limit) { + LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit); +} 
namespace vllm { // Element-wise activation kernel template. diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu index e0e95d06290df..6dd6f269f3dc9 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -167,7 +167,7 @@ typename T::Fmha::Arguments args_from_options( // TODO(trevor-m): Change split_kv back to -1 when // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will // perform worse with larger context length and smaller batch sizes. - num_kv_splits, // split_kv + static_cast(num_kv_splits), // split_kv nullptr, // is_var_split_kv }; // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute @@ -264,7 +264,7 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba // Assumes device 0 when getting sm_count. arguments.hw_info.sm_count = sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count; - arguments.split_kv = num_kv_splits; + arguments.split_kv = static_cast(num_kv_splits); MlaSm100Type::Fmha::set_split_kv(arguments); return MlaSm100Type::Fmha::get_workspace_size(arguments); diff --git a/csrc/cache.h b/csrc/cache.h index 0970b704be3ab..fb0c353b96137 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -40,9 +40,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); -void gather_cache( +void gather_and_maybe_dequant_cache( torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] 
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] torch::Tensor const& cu_seq_lens, // [BATCH+1] - int64_t batch_size, std::optional seq_starts = std::nullopt); \ No newline at end of file + int64_t batch_size, const std::string& kv_cache_dtype, + torch::Tensor const& scale, + std::optional seq_starts = std::nullopt); \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 131dcb15cd7e9..b3a985c2d5bbb 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -624,9 +624,9 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, namespace vllm { // grid is launched with dimensions (batch, num_splits) -template -__global__ void gather_cache( - const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, +template +__global__ void gather_and_maybe_dequant_cache( + const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, // ENTRIES...] scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...] const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] @@ -634,6 +634,7 @@ __global__ void gather_cache( const int32_t block_size, const int32_t entry_size, const int64_t block_table_stride, const int64_t cache_block_stride, const int64_t cache_entry_stride, const int64_t dst_entry_stride, + const float* __restrict__ scale, const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per // batch @@ -675,10 +676,16 @@ __global__ void gather_cache( if (partial_block_size) full_blocks_end -= 1; } - auto copy_entry = [&](const scalar_t* __restrict__ _src, + auto copy_entry = [&](const cache_t* __restrict__ _src, scalar_t* __restrict__ _dst) { - for (int i = threadIdx.x; i < entry_size; i += blockDim.x) - _dst[i] = _src[i]; + for (int i = threadIdx.x; i < entry_size; i += blockDim.x) { + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + _dst[i] = static_cast(_src[i]); + } else { + _dst[i] = + fp8::scaled_convert(_src[i], *scale); + } + } }; for (int pid = split_start; pid < 
full_blocks_end; ++pid) { @@ -705,25 +712,31 @@ __global__ void gather_cache( } // namespace vllm // Macro to dispatch the kernel based on the data type. -#define CALL_GATHER_CACHE(CPY_DTYPE) \ - vllm::gather_cache<<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst.data_ptr()), \ - block_table.data_ptr(), cu_seq_lens.data_ptr(), \ - block_size, entry_size, block_table_stride, cache_block_stride, \ - cache_entry_stride, dst_entry_stride, seq_starts_ptr); +// SCALAR_T is the data type of the destination tensor. +// CACHE_T is the stored data type of kv-cache. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \ + vllm::gather_and_maybe_dequant_cache \ + <<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + block_size, entry_size, block_table_stride, cache_block_stride, \ + cache_entry_stride, dst_entry_stride, \ + reinterpret_cast(scale.data_ptr()), seq_starts_ptr); // Gather sequences from the cache into the destination tensor. // - cu_seq_lens contains the cumulative sequence lengths for each batch // - block_table contains the cache block indices for each sequence // - Optionally, seq_starts (if provided) offsets the starting block index by // (seq_starts[bid] / page_size) -void gather_cache( +void gather_and_maybe_dequant_cache( torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] 
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] torch::Tensor const& cu_seq_lens, // [BATCH+1] - int64_t batch_size, + int64_t batch_size, const std::string& kv_cache_dtype, + torch::Tensor const& scale, std::optional seq_starts = std::nullopt) { at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -761,20 +774,8 @@ void gather_cache( dim3 grid(batch_size, num_splits); dim3 block(1024); - TORCH_CHECK(src_cache.dtype() == dst.dtype(), - "src_cache and dst must have the same dtype"); - - const int dtype_bits = src_cache.element_size() * 8; const int32_t* seq_starts_ptr = seq_starts.has_value() ? seq_starts.value().data_ptr() : nullptr; - if (dtype_bits == 32) { - CALL_GATHER_CACHE(uint32_t); - } else if (dtype_bits == 16) { - CALL_GATHER_CACHE(uint16_t); - } else if (dtype_bits == 8) { - CALL_GATHER_CACHE(uint8_t); - } else { - TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); - } + DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE); } diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index 3952c43cbc727..982f7c07a13bd 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -89,7 +89,7 @@ struct FP16Vec16 : public Vec { explicit FP16Vec16(const FP32Vec16&); - void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } + void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); } void save(void* ptr, const int elem_num) const { constexpr uint32_t M = 0xFFFFFFFF; @@ -126,7 +126,7 @@ struct BF16Vec16 : public Vec { explicit BF16Vec16(const FP32Vec16&); - void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } + void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); } void save(void* ptr, const int elem_num) const { constexpr uint32_t M = 0xFFFFFFFF; @@ -180,8 +180,8 @@ struct BF16Vec32 : public Vec { (__m128i)vec8_data.reg, 1)) {} void save(void* ptr) const 
{ - *reinterpret_cast<__m256i*>(ptr) = reg_low; - *reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high; + _mm256_storeu_si256((__m256i*)ptr, reg_low); + _mm256_storeu_si256((__m256i*)ptr + 1, reg_high); } }; #endif diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp new file mode 100644 index 0000000000000..f3f00edb36068 --- /dev/null +++ b/csrc/cpu/dnnl_helper.cpp @@ -0,0 +1,346 @@ +#include +#include + +#include "common/memory_desc.hpp" +#include "common/memory.hpp" + +#include "dnnl_helper.h" + +static dnnl::engine& default_engine() { + static dnnl::engine engine(dnnl::engine::kind::cpu, 0); + return engine; +} + +static dnnl::stream& default_stream() { + static dnnl::stream stream(default_engine()); + return stream; +} + +void release_dnnl_matmul_handler(int64_t handler) { + DNNLMatMulPrimitiveHandler* ptr = + reinterpret_cast(handler); + delete ptr; +} + +template +class DNNLPrimitiveCache { + public: + using cache_value_t = std::pair; + using result_value_t = VT; + using container_t = std::list; + using value_iterator_t = typename container_t::iterator; + using map_t = std::unordered_map; + using creator_t = VT (*)(); + + public: + DNNLPrimitiveCache(size_t capacity) + : capacity_(capacity), + values_(), + key_to_value_(std::min(256lu, capacity)) { + assert(capacity > 0); + } + + template + result_value_t get_or_create(const KT& key, F&& creator) { + std::optional value = get_value(key); + if (value.has_value()) { + return value.value()->second; + } else { + return add_value({key, creator()})->second; + } + } + + size_t size() const { return values_.size(); } + + private: + void dump_data() { + std::stringstream ss; + ss << "table_id: " << std::hex << reinterpret_cast(this) << std::dec + << "\n"; + ss << "container: ["; + for (auto&& iter : values_) { + ss << "(" << iter.first << ", " << std::hex + << reinterpret_cast(iter.second.get()) << "), " << std::dec; + } + ss << "]\n"; + + ss << "map: ["; + for (auto&& iter : key_to_value_) { + ss << 
"(" << iter.first << ", " << iter.second->first << ", " << std::hex + << reinterpret_cast(iter.second->second.get()) << std::dec + << "), "; + } + ss << "]\n"; + std::printf("%s\n", ss.str().c_str()); + } + + value_iterator_t add_value(cache_value_t&& new_value) { + if (size() == capacity_) { + cache_value_t& last_item = values_.back(); + key_to_value_.erase(last_item.first); + values_.pop_back(); + } + + auto& added_value_ = values_.emplace_front(std::move(new_value)); + key_to_value_.emplace(added_value_.first, values_.begin()); + return values_.begin(); + } + + std::optional get_value(const KT& key) { + if (key_to_value_.size() > 0 && key == values_.begin()->first) { + return values_.begin(); + } + + auto value_map_iterator = key_to_value_.find(key); + if (value_map_iterator != key_to_value_.end()) { + values_.splice(values_.begin(), values_, value_map_iterator->second); + return value_map_iterator->second; + } else { + return {}; + } + } + + private: + const size_t capacity_; + container_t values_; + map_t key_to_value_; +}; + +DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler( + const Args& args, dnnl::memory::data_type b_type) + : b_n_size_(args.b_n_size), + b_n_stride_(args.b_n_stride), + b_k_size_(args.b_k_size), + b_k_stride_(args.b_k_stride), + b_type_(b_type), + c_type_(args.c_type), + runtime_memory_ptrs_(8), + primitive_cache_size_(args.primitive_cache_size) { + assert(primitive_cache_size_ > 0); +} + +void DNNLMatMulPrimitiveHandler::prepack_weight( + void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) { + dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, + {b_k_stride_, b_n_stride_}); + dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr); + dnnl::memory packed_weight(b_target_mem_desc, default_engine()); + { + dnnl::reorder(original_weight, packed_weight) + .execute(default_stream(), original_weight, packed_weight); + default_stream().wait(); + } + memory_cache_[DNNL_ARG_WEIGHTS] = packed_weight; 
+ b_target_mem_desc_ = b_target_mem_desc; +} + +void DNNLMatMulPrimitiveHandler::set_runtime_memory_ptr( + size_t index, dnnl_memory* memory_ptr) { + dnnl::impl::memory_storage_t* mem_storage_ptr = memory_ptr->memory_storage(); + dnnl_memory_desc* mem_desc = const_cast(memory_ptr->md()); + runtime_memory_ptrs_[index] = {mem_storage_ptr, mem_desc}; +} + +std::pair +DNNLMatMulPrimitiveHandler::get_runtime_memory_ptr(size_t index) { + return runtime_memory_ptrs_[index]; +} + +namespace std { +template <> +struct hash { + size_t operator()( + const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const { + return hash()(val.b_n_size) ^ hash()(val.b_k_size) ^ + hash()(static_cast(val.a_qs)) ^ + hash()(static_cast(val.b_qs)) ^ hash()(val.use_azp) ^ + hash()(static_cast(val.c_type)); + } +}; + +template <> +struct hash { + size_t operator()( + const W8A8MatMulPrimitiveHandler::MSizeCacheKey& val) const { + return hash()(val.a_m_size) ^ hash()(val.use_bias) ^ + hash()(static_cast(val.bias_type)); + } +}; +} // namespace std + +bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l, + const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& r) { + return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size && + l.a_qs == r.a_qs && l.b_qs == r.b_qs && l.use_azp == r.use_azp && + l.c_type == r.c_type; +} + +bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l, + const W8A8MatMulPrimitiveHandler::MSizeCacheKey& r) { + return l.use_bias == r.use_bias && l.a_m_size == r.a_m_size && + l.bias_type == r.bias_type; +} + +static std::shared_ptr +get_w8a8_class_primitive_cache( + const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key, + int64_t cache_size) { + static W8A8MatMulPrimitiveHandler::ClassMatmulCache cache(128); + assert(cache_size > 0); + return cache.get_or_create(key, [&]() { + return std::make_shared(cache_size); + }); +} + +W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args) + : DNNLMatMulPrimitiveHandler( + 
static_cast(args), + dnnl::memory::data_type::s8), + use_azp_(args.use_a_zero_point), + a_qs_(args.a_quantization_strategy), + b_qs_(args.b_quantization_strategy), + m_size_cache_(nullptr) { + assert(a_qs_ != QuantizationStrategy::PER_OUTPUT_CHANNEL); + assert(b_qs_ != QuantizationStrategy::PER_TOKEN); + if (a_qs_ == QuantizationStrategy::PER_TOKEN) { + assert(!use_azp_); + }; + prepack_weight(args.b_ptr, + create_primitive_desc( + MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL, + .use_bias = false, + .bias_type = dnnl::memory::data_type::undef}, + true) + .weights_desc()); + init_runtime_memory_cache(args); +} + +void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) { + auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0); + auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1); + a_storage->set_data_handle((void*)args.a_ptr); + a_mem_desc->dims[0] = args.a_m_size; + c_storage->set_data_handle((void*)args.c_ptr); + c_mem_desc->dims[0] = args.a_m_size; + + if (a_qs_ == QuantizationStrategy::PER_TENSOR) { + auto&& [a_scale_storage, a_scale_mem_desc] = get_runtime_memory_ptr(2); + a_scale_storage->set_data_handle((void*)args.a_scales_ptr); + } + if (use_azp_) { + auto&& [a_zero_point_storage, a_zero_point_mem_desc] = + get_runtime_memory_ptr(3); + a_zero_point_storage->set_data_handle((void*)args.a_zero_points_ptr); + } + + if (args.use_bias) { + auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(4); + bias_storage->set_data_handle((void*)args.bias_ptr); + } + + dnnl::matmul matmul = get_matmul_cache(args); + matmul.execute(default_stream(), memory_cache_); + default_stream().wait(); +} + +dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache( + const MSizeCacheKey& key) { + if (m_size_cache_.get() == nullptr) { + ClassMatmulCacheKey key = {.b_n_size = b_n_size_, + .b_k_size = b_k_size_, + .a_qs = a_qs_, + .b_qs = b_qs_, + .use_azp = use_azp_, + .c_type = c_type_}; + m_size_cache_ = get_w8a8_class_primitive_cache(key, 
primitive_cache_size_); + } + + return m_size_cache_->get_or_create(key, [&]() { + dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); + return dnnl::matmul(desc); + }); +} + +void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) { + memory_cache_[DNNL_ARG_SRC] = dnnl::memory({{1, b_k_size_}, + dnnl::memory::data_type::s8, + dnnl::memory::format_tag::ab}, + default_engine(), nullptr); + set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get()); + memory_cache_[DNNL_ARG_DST] = + dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab}, + default_engine(), nullptr); + set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get()); + + // For PER_TOKEN, scales will be applied in outside epilogue + if (a_qs_ == QuantizationStrategy::PER_TENSOR) { + memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = dnnl::memory( + {{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr); + set_runtime_memory_ptr( + 2, memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC].get()); + if (use_azp_) { + memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = dnnl::memory( + {{1}, dnnl::memory::data_type::s32, {1}}, default_engine(), nullptr); + set_runtime_memory_ptr( + 3, memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC].get()); + } + } + + if (b_qs_ == QuantizationStrategy::PER_TENSOR) { + memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = + dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), + (void*)args.b_scales_ptr); + } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) { + memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = + dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, + default_engine(), (void*)args.b_scales_ptr); + } + + memory_cache_[DNNL_ARG_BIAS] = + dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, + default_engine(), nullptr); + set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get()); +} + +dnnl::matmul::primitive_desc 
W8A8MatMulPrimitiveHandler::create_primitive_desc( + const MSizeCacheKey& key, bool first_time) { + dnnl::memory::desc a_md({key.a_m_size, b_k_size_}, + dnnl::memory::data_type::s8, + dnnl::memory::format_tag::ab); + dnnl::memory::desc b_md; + if (first_time) { + b_md = + dnnl::memory::desc({b_k_size_, b_n_size_}, dnnl::memory::data_type::s8, + dnnl::memory::format_tag::any); + } else { + b_md = b_target_mem_desc_; + } + dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_, + dnnl::memory::format_tag::ab); + + dnnl::primitive_attr attr; + // For PER_TOKEN, scales will be applied in outside epilogue + if (a_qs_ == QuantizationStrategy::PER_TENSOR) { + attr.set_scales_mask(DNNL_ARG_SRC, 0); + if (use_azp_) { + attr.set_zero_points_mask(DNNL_ARG_SRC, 0); + } + } + + if (b_qs_ == QuantizationStrategy::PER_TENSOR) { + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); + } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) { + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2); + } + + if (key.use_bias) { + // For PER_TOKEN, bias will be applied in epilogue + assert(a_qs_ == QuantizationStrategy::PER_TENSOR); + dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1}); + return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md, + c_md, attr); + } else { + return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md, + attr); + } +} diff --git a/csrc/cpu/dnnl_helper.h b/csrc/cpu/dnnl_helper.h new file mode 100644 index 0000000000000..54ceefced9e98 --- /dev/null +++ b/csrc/cpu/dnnl_helper.h @@ -0,0 +1,169 @@ +#ifndef DNNL_HELPER_H +#define DNNL_HELPER_H + +#include +#include + +#include "oneapi/dnnl/dnnl.hpp" + +namespace c10 { +struct BFloat16; +struct Half; +} // namespace c10 + +namespace dnnl { +namespace impl { +struct memory_storage_t; +struct matmul_pd_t; +struct matmul_desc_t; +} // namespace impl +} // namespace dnnl +struct dnnl_memory_desc; + +template +class DNNLPrimitiveCache; + +template +struct DNNLType { + static 
constexpr dnnl::memory::data_type type = + dnnl::memory::data_type::undef; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16; +}; + +template +constexpr inline dnnl::memory::data_type get_dnnl_type() { + return DNNLType>::type; +} + +class DNNLMatMulPrimitiveHandler { + public: + virtual ~DNNLMatMulPrimitiveHandler() = default; + + protected: + struct Args { + dnnl_dim_t b_n_size; + dnnl_dim_t b_n_stride; + dnnl_dim_t b_k_size; + dnnl_dim_t b_k_stride; + void* b_ptr; + dnnl::memory::data_type c_type; + size_t primitive_cache_size; + }; + + protected: + DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type); + + void prepack_weight(void* original_b_ptr, + dnnl::memory::desc b_target_mem_desc); + + void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr); + + std::pair + get_runtime_memory_ptr(size_t index); + + protected: + const dnnl_dim_t b_n_size_; + const dnnl_dim_t b_n_stride_; + const dnnl_dim_t b_k_size_; + const dnnl_dim_t b_k_stride_; + dnnl::memory::data_type b_type_; + dnnl::memory::data_type c_type_; + std::unordered_map memory_cache_; + std::vector> + runtime_memory_ptrs_; + dnnl::memory::desc b_target_mem_desc_; + int64_t primitive_cache_size_; +}; + +class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler { + public: + enum class QuantizationStrategy { PER_TOKEN, PER_TENSOR, PER_OUTPUT_CHANNEL }; + + struct Args : public DNNLMatMulPrimitiveHandler::Args { + bool 
use_a_zero_point; + QuantizationStrategy a_quantization_strategy; + QuantizationStrategy b_quantization_strategy; + float* b_scales_ptr; + }; + + struct ClassMatmulCacheKey { + dnnl_dim_t b_n_size; + dnnl_dim_t b_k_size; + QuantizationStrategy a_qs; + QuantizationStrategy b_qs; + bool use_azp; + dnnl::memory::data_type c_type; + + friend bool operator==(const ClassMatmulCacheKey& l, + const ClassMatmulCacheKey& r); + }; + + struct MSizeCacheKey { + dnnl_dim_t a_m_size; + bool use_bias; + dnnl::memory::data_type bias_type; + + friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r); + }; + + using MSizeCache = DNNLPrimitiveCache; + using ClassMatmulCache = + DNNLPrimitiveCache>; + + struct ExecArgs : public MSizeCacheKey { + const int8_t* a_ptr; + const float* a_scales_ptr; + const int32_t* a_zero_points_ptr; + const void* bias_ptr; + void* c_ptr; + }; + + public: + W8A8MatMulPrimitiveHandler(const Args& args); + + QuantizationStrategy get_input_scale_strategy() const { return a_qs_; } + + bool get_input_use_zero_point() const { return use_azp_; } + + void execute(ExecArgs& args); + + private: + dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key, + bool first_time); + + void init_runtime_memory_cache(const Args& args); + + dnnl::matmul get_matmul_cache(const MSizeCacheKey& key); + + private: + const bool use_azp_; + const QuantizationStrategy a_qs_; + const QuantizationStrategy b_qs_; + std::shared_ptr m_size_cache_; +}; + +#endif diff --git a/csrc/cpu/dnnl_helper.hpp b/csrc/cpu/dnnl_helper.hpp deleted file mode 100644 index 1cb8dc5b25a66..0000000000000 --- a/csrc/cpu/dnnl_helper.hpp +++ /dev/null @@ -1,206 +0,0 @@ -#ifndef DNNL_HELPER_HPP -#define DNNL_HELPER_HPP - -#include -#include - -#include "oneapi/dnnl/dnnl.hpp" - -namespace { -template -struct DNNLType { - static constexpr dnnl::memory::data_type type = - dnnl::memory::data_type::undef; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type 
type = dnnl::memory::data_type::s8; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16; -}; - -template -constexpr inline dnnl::memory::data_type get_dnnl_type() { - return DNNLType>::type; -} -}; // namespace - -template -class DNNLPrimitiveHelper { - public: - // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias) - // A: [M, K], row-major - // B: [K, N], column-major - // C: [M, N], row-major - // bias: [N], row-major, optional - // a_scales: [MS] - // b_scales: [NS] - // Note: Due to the limitation of oneDNN - // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is - // not supported. 
- - template - static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c, - const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N, - dnnl_dim_t K, const float* a_scales, - const float* b_scales, dnnl_dim_t MS, - dnnl_dim_t NS) { - auto&& OutputType = get_dnnl_type(); - auto&& BiasType = get_dnnl_type(); - - dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1}); - dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K}); - dnnl::memory::desc c_md({M, N}, OutputType, {N, 1}); - - dnnl::primitive_attr attr; - if constexpr (!InputNoScale) { - if (MS == 1) { - // per-tensor - attr.set_scales_mask(DNNL_ARG_SRC, 0); - } else { - // per-token - TORCH_CHECK(false, "per-token quantization is unsupported."); - } - } - - if (NS == 1) { - // per-tensor - attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); - } else { - // per-channel - attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2); - } - - dnnl::matmul::primitive_desc matmul_pd; -// Create memory descriptors with format_tag::any for the primitive. This -// enables the matmul primitive to choose memory layouts for an -// optimized primitive implementation, and these layouts may differ from the -// ones provided by the user. 
-#ifdef __aarch64__ - auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8, - dnnl::memory::format_tag::any); - auto mat_weights_md = dnnl::memory::desc( - {K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any); - auto mat_dst_md = - dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any); - if (bias) { - dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); - matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md, - mat_weights_md, bias_md, - mat_dst_md, attr); - } else { - matmul_pd = dnnl::matmul::primitive_desc( - default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr); - } -#else - if (bias) { - dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); - matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, - bias_md, c_md, attr); - } else { - matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, - c_md, attr); - } -#endif - dnnl::matmul matmul(matmul_pd); - - auto& engine = default_engine(); - - dnnl::memory a_m(a_md, engine, (void*)a); - dnnl::memory b_m(b_md, engine, (void*)b); - dnnl::memory c_m(c_md, engine, (void*)c); - dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine, - (void*)a_scales); - dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine, - (void*)b_scales); - - auto& stream = default_stream(); - - auto mat_src_mem = a_m; - auto mat_weights_mem = b_m; - auto mat_dst_mem = c_m; -#ifdef __aarch64__ - if (matmul_pd.weights_desc() != b_m.get_desc()) { - mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine); - dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem); - } -#endif - if constexpr (InputNoScale) { - if (bias) { - dnnl::memory::desc bias_md({N}, BiasType, {1}); - dnnl::memory bias_m(bias_md, engine, (void*)bias); - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_BIAS, bias_m}, - {DNNL_ARG_DST, mat_dst_mem}, - 
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } else { - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } - } else { - if (bias) { - dnnl::memory::desc bias_md({N}, BiasType, {1}); - dnnl::memory bias_m(bias_md, engine, (void*)bias); - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_BIAS, bias_m}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } else { - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } - } - stream.wait(); - } - - private: - static dnnl::engine& default_engine() { - static dnnl::engine engine(dnnl::engine::kind::cpu, 0); - return engine; - } - - static dnnl::stream& default_stream() { - static dnnl::stream stream(default_engine()); - return stream; - } -}; -#endif diff --git a/csrc/cpu/dnnl_kernels.cpp b/csrc/cpu/dnnl_kernels.cpp new file mode 100644 index 0000000000000..acc3b9ecde143 --- /dev/null +++ b/csrc/cpu/dnnl_kernels.cpp @@ -0,0 +1,494 @@ +#include "cpu_types.hpp" +#include "dnnl_helper.h" + +namespace { +template +struct KernelVecType { + using load_vec_type = void; + using cvt_vec_type = void; +}; + +template <> +struct KernelVecType { + using load_vec_type = vec_op::FP32Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; + +#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT) +template <> +struct KernelVecType { + using load_vec_type = vec_op::BF16Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; +#endif + +template <> +struct KernelVecType { +#if defined(__powerpc64__) || defined(__s390x__) + // Power 
architecture-specific vector type + using load_vec_type = vec_op::FP32Vec16; +#else + // Fallback for other architectures + using load_vec_type = vec_op::FP16Vec16; +#endif + using cvt_vec_type = vec_op::FP32Vec16; +}; + +template +void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + const float* scale, const int32_t* azp, + const int64_t num_tokens, + const int64_t input_stride, + const int64_t hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int64_t vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t inv_scale(1.0 / *scale); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + + cvt_vec_t zp_vec; + if constexpr (AZP) { + zp_vec = cvt_vec_t(static_cast(*azp)); + } + +#pragma omp parallel for + for (int64_t i = 0; i < num_tokens; ++i) { + int64_t j = 0; + const scalar_t* input_ptr = input + i * input_stride; + int8_t* output_ptr = output + i * hidden_size; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j); + } + + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j, hidden_size - j); + } +} + +template +void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + float* scale, int32_t* azp, + const 
int64_t num_tokens, + const int64_t input_stride, + const int64_t hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + +#pragma omp parallel for + for (int64_t i = 0; i < num_tokens; ++i) { + cvt_vec_t max_value(std::numeric_limits::lowest()); + cvt_vec_t min_value(std::numeric_limits::max()); + { + int64_t j = 0; + const scalar_t* input_ptr = input + i * input_stride; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } + + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + + if (j + vec_elem_num == hidden_size) { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } else { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32, hidden_size - j); + min_value = min_value.min(elems_fp32, hidden_size - j); + } else { + max_value = max_value.max(elems_fp32.abs(), hidden_size - j); + } + } + } + + float scale_val, azp_val; + if constexpr (AZP) { + float max_scalar = max_value.reduce_max(); + float min_scalar = min_value.reduce_min(); + scale_val = (max_scalar - min_scalar) / 255.0f; + azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); + azp[i] = azp_val; + scale[i] = scale_val; + } else { + scale_val = max_value.reduce_max() / 127.0f; + scale[i] = scale_val; + } + + const cvt_vec_t inv_scale(1.0 / scale_val); + const cvt_vec_t 
azp_vec(azp_val); + + { + int64_t j = 0; + const scalar_t* input_ptr = input + i * input_stride; + int8_t* output_ptr = output + i * hidden_size; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j); + } + + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j, hidden_size - j); + } + } +} + +template +void dynamic_quant_epilogue(const float* input, scalar_t* output, + const float* a_scale, const int32_t* azp, + const float* azp_adj, const scalar_t* bias, + const int64_t num_tokens, + const int64_t hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + const int64_t thread_num = omp_get_max_threads(); + if (num_tokens > thread_num) { +#pragma omp parallel for + for (int64_t i = 0; i < num_tokens; ++i) { + const float* input_ptr = input + i * hidden_size; + scalar_t* output_ptr = output + i * hidden_size; + int64_t j = 0; + cvt_vec_t token_scale_vec(a_scale[i]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = a_scale[i] * static_cast(azp[i]); + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + for (; j < hidden_size - vec_elem_num; ++j) { + cvt_vec_t elems_fp32(input_ptr + j); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + j); + elems_fp32 
= elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + elems_out.save(output_ptr + j); + } + cvt_vec_t elems_fp32(input_ptr + j); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + j); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + elems_out.save(output_ptr + j, hidden_size - j); + } + } else { + const int64_t vec_iteration = + (hidden_size + vec_elem_num - 1) / vec_elem_num; + const int64_t vec_iteration_per_thread = + (vec_iteration + thread_num - 1) / thread_num; + const int64_t elem_num_per_thread = vec_iteration_per_thread * vec_elem_num; +#pragma omp parallel for schedule(static, 1) + for (int64_t i = 0; i < thread_num; ++i) { + const int64_t start = elem_num_per_thread * i; + const int64_t end = std::min(hidden_size, elem_num_per_thread + start); + for (int64_t j = 0; j < num_tokens; ++j) { + cvt_vec_t token_scale_vec(a_scale[j]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = a_scale[j] * static_cast(azp[j]); + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + int64_t k = start; + const float* input_ptr = input + j * hidden_size; + scalar_t* output_ptr = output + j * hidden_size; + for (; k < end - vec_elem_num; k += vec_elem_num) { + cvt_vec_t elems_fp32(input_ptr + k); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + k); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + k); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t 
elems_out(elems_fp32); + elems_out.save(output_ptr + k); + } + if (k < end) { + cvt_vec_t elems_fp32(input_ptr + k); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + k); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + k); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + elems_out.save(output_ptr + k, end - k); + } + } + } + } +} +} // namespace + +int64_t create_onednn_scaled_mm_handler( + const torch::Tensor& b, // [IC, OC], column-major + const torch::Tensor& b_scales, // [1] or [OC] + at::ScalarType output_type, bool dynamic_act_quant, bool use_azp, + int64_t primitive_cache_size) { + TORCH_CHECK(b.dim() == 2); + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(b_scales.is_contiguous()); + + W8A8MatMulPrimitiveHandler::Args args; + args.primitive_cache_size = primitive_cache_size; + + if (b_scales.numel() == 1) { + args.b_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR; + } else { + TORCH_CHECK_EQ(b_scales.numel(), b.size(1)); + args.b_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_OUTPUT_CHANNEL; + } + args.b_scales_ptr = b_scales.data_ptr(); + args.b_k_size = b.size(0); + args.b_k_stride = b.stride(0); + args.b_n_size = b.size(1); + args.b_n_stride = b.stride(1); + args.b_ptr = b.data_ptr(); + + if (dynamic_act_quant) { + // dynamic per-token, bias, A scales and A zps will be applied in outside. 
+ args.a_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN; + args.use_a_zero_point = false; + } else { + // static per-tensor + args.a_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR; + args.use_a_zero_point = use_azp; + } + + VLLM_DISPATCH_FLOATING_TYPES(output_type, "create_onednn_scaled_mm_handler", + [&] { + if (dynamic_act_quant) { + args.c_type = get_dnnl_type(); + } else { + args.c_type = get_dnnl_type(); + } + }); + + return reinterpret_cast(new W8A8MatMulPrimitiveHandler(args)); +} + +void onednn_scaled_mm( + torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const torch::Tensor& a_scales, // [M] or [1] + const std::optional& azp, // [M] or [1] + const std::optional& azp_adj, // [M] or [1] + const std::optional& bias, // [N] + int64_t handler) { + CPU_KERNEL_GUARD_IN(onednn_scaled_mm) + TORCH_CHECK(a.dim() == 2); + TORCH_CHECK(a.is_contiguous()); + TORCH_CHECK(c.is_contiguous()); + W8A8MatMulPrimitiveHandler* ptr = + reinterpret_cast(handler); + const int32_t* azp_ptr = nullptr; + if (azp.has_value()) { + azp_ptr = azp->data_ptr(); + } + if (ptr->get_input_scale_strategy() == + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) { + TORCH_CHECK_EQ(a_scales.numel(), 1); + } + + W8A8MatMulPrimitiveHandler::ExecArgs exec_args; + exec_args.a_ptr = a.data_ptr(); + exec_args.a_m_size = a.size(0); + exec_args.bias_ptr = nullptr; + exec_args.use_bias = false; + exec_args.a_scales_ptr = nullptr; + exec_args.a_zero_points_ptr = nullptr; + + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "onednn_scaled_mm", [&] { + if (ptr->get_input_scale_strategy() == + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) { + if (bias.has_value()) { + exec_args.bias_ptr = bias->data_ptr(); + exec_args.bias_type = get_dnnl_type(); + exec_args.use_bias = true; + } + exec_args.a_scales_ptr = a_scales.data_ptr(); + exec_args.a_zero_points_ptr = azp_ptr; + 
exec_args.c_ptr = c.data_ptr(); + ptr->execute(exec_args); + } else if (ptr->get_input_scale_strategy() == + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN) { + torch::Tensor tmp_fp32_out = + torch::empty_like(c, ::at::ScalarType::Float); + exec_args.c_ptr = tmp_fp32_out.data_ptr(); + ptr->execute(exec_args); + if (bias.has_value()) { + if (azp.has_value()) { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, azp_adj->data_ptr(), + bias->data_ptr(), c.size(0), c.size(1)); + } else { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, nullptr, + bias->data_ptr(), c.size(0), c.size(1)); + } + } else { + if (azp.has_value()) { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, azp_adj->data_ptr(), + (scalar_t*)nullptr, c.size(0), c.size(1)); + } else { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, nullptr, (scalar_t*)nullptr, + c.size(0), c.size(1)); + } + } + } else { + TORCH_CHECK(false, "invalid act quant type."); + } + }); +} + +// static-per-tensor quantization. 
+void static_scaled_int8_quant( + torch::Tensor& out, // [batch, hidden_size] + const torch::Tensor& input, // [batch, hidden_size] + const torch::Tensor& scale, std::optional const& azp) { + CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK_EQ(input.dim(), 2); + TORCH_CHECK_EQ(input.stride(1), 1); + TORCH_CHECK(scale.numel() == 1); + TORCH_CHECK(!azp.has_value() || azp->numel() == 1); + + const int64_t stride = input.stride(0); + const int64_t hidden_size = input.size(1); + const int64_t num_tokens = input.size(0); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "static_scaled_int8_quant_impl", [&] { + if (azp.has_value()) { + static_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + stride, hidden_size); + } else { + static_scaled_int8_quant_impl(input.data_ptr(), + out.data_ptr(), + scale.data_ptr(), nullptr, + num_tokens, stride, hidden_size); + } + }); +} + +// dynamic-per-token quantization. 
+void dynamic_scaled_int8_quant( + torch::Tensor& out, // [batch, hidden_size] + const torch::Tensor& input, // [batch, hidden_size] + torch::Tensor& scale, // [batch, 1] + std::optional const& azp) { + CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK_EQ(input.dim(), 2); + TORCH_CHECK_EQ(input.stride(1), 1); + + const int64_t hidden_size = input.size(1); + const int64_t num_tokens = input.size(0); + const int64_t stride = input.stride(0); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { + if (azp.has_value()) { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + stride, hidden_size); + } else { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), nullptr, num_tokens, stride, + hidden_size); + } + }); +} diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp deleted file mode 100644 index 6e120b8d20a7e..0000000000000 --- a/csrc/cpu/quant.cpp +++ /dev/null @@ -1,951 +0,0 @@ -#include "cpu_types.hpp" -#include "dnnl_helper.hpp" - -namespace { -template -struct KernelVecType { - using load_vec_type = void; - using azp_adj_load_vec_type = void; - using cvt_vec_type = void; -}; - -template <> -struct KernelVecType { - using load_vec_type = vec_op::FP32Vec16; - using azp_adj_load_vec_type = vec_op::INT32Vec16; - using cvt_vec_type = vec_op::FP32Vec16; -}; - -#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT) -template <> -struct KernelVecType { - using load_vec_type = vec_op::BF16Vec16; - using azp_adj_load_vec_type = vec_op::INT32Vec16; - using cvt_vec_type = vec_op::FP32Vec16; -}; -#endif - -template <> -struct KernelVecType { -#if defined(__powerpc64__) || defined(__s390x__) - // Power architecture-specific vector type - using load_vec_type = vec_op::FP32Vec16; -#else - // Fallback for other architectures - using load_vec_type = vec_op::FP16Vec16; -#endif - using 
azp_adj_load_vec_type = vec_op::INT32Vec16; - using cvt_vec_type = vec_op::FP32Vec16; -}; - -#if defined(__AVX512F__) || defined(__aarch64__) -template -void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - const cvt_vec_t inv_scale(1.0 / *scale); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - cvt_vec_t zp_vec; - if constexpr (AZP) { - zp_vec = cvt_vec_t(static_cast(*azp)); - } - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } -} - -template -void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int 
vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t max_value(std::numeric_limits::lowest()); - cvt_vec_t min_value(std::numeric_limits::max()); - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - - if (j + vec_elem_num == hidden_size) { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } else { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32, hidden_size - j); - min_value = min_value.min(elems_fp32, hidden_size - j); - } else { - max_value = max_value.max(elems_fp32.abs(), hidden_size - j); - } - } - } - - float scale_val, azp_val; - if constexpr (AZP) { - float max_scalar = max_value.reduce_max(); - float min_scalar = min_value.reduce_min(); - scale_val = (max_scalar - min_scalar) / 255.0f; - azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); - azp[i] = static_cast(azp_val); - scale[i] = scale_val; - } else { - scale_val = max_value.reduce_max() / 127.0f; - scale[i] = scale_val; - } - - const cvt_vec_t inv_scale(1.0 / scale_val); - const cvt_vec_t azp_vec(azp_val); - - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) 
{ - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } - } -} - -template -void static_quant_epilogue(const float* input, scalar_t* output, - const float a_scale, const float* b_scale, - const int32_t* azp_with_adj, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t a_scale_vec(a_scale); - cvt_vec_t b_scale_vec(*b_scale); - cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; - - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * 
a_scale_vec; - } - - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} - -template -void dynamic_quant_epilogue(const float* input, scalar_t* output, - const float* a_scale, const float* b_scale, - const int32_t* azp, const int32_t* azp_adj, - const scalar_t* bias, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - cvt_vec_t token_scale_vec(a_scale[i]); - cvt_vec_t token_zp_scale_vec; - if constexpr (AZP) { - float zp_scale_val = a_scale[i] * static_cast(azp[i]); - if constexpr (!PerChannel) { - zp_scale_val *= *b_scale; - } - token_zp_scale_vec = cvt_vec_t(zp_scale_val); - } - - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t 
azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} -#elif defined(__powerpc64__) -template -void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - - const cvt_vec_t inv_scale(1.0 / *scale); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - cvt_vec_t zp_vec; - if constexpr (AZP) { - zp_vec = cvt_vec_t(static_cast(*azp)); - } - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 
elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } -} -template -void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t max_value(std::numeric_limits::lowest()); - cvt_vec_t min_value(std::numeric_limits::max()); - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - - if (j + vec_elem_num == hidden_size) { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } else { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32, hidden_size - j); - min_value = min_value.min(elems_fp32, hidden_size - j); - } else { - max_value = max_value.max(elems_fp32.abs(), hidden_size - j); - } - } - } - - float scale_val, azp_val; - if constexpr (AZP) { - float max_scalar = max_value.reduce_max(); - float min_scalar = min_value.reduce_min(); - scale_val = (max_scalar - min_scalar) / 255.0f; - azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); - azp[i] = static_cast(azp_val); - scale[i] = 
scale_val; - } else { - scale_val = max_value.reduce_max() / 127.0f; - scale[i] = scale_val; - } - - const cvt_vec_t inv_scale(1.0 / scale_val); - const cvt_vec_t azp_vec(azp_val); - - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } - } -} -template -void static_quant_epilogue(const float* input, scalar_t* output, - const float a_scale, const float* b_scale, - const int32_t* azp_with_adj, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t a_scale_vec(a_scale); - cvt_vec_t b_scale_vec(*b_scale); - cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; - - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - 
} - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} -template -void dynamic_quant_epilogue(const float* input, scalar_t* output, - const float* a_scale, const float* b_scale, - const int32_t* azp, const int32_t* azp_adj, - const scalar_t* bias, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - cvt_vec_t token_scale_vec(a_scale[i]); - cvt_vec_t token_zp_scale_vec; - if constexpr (AZP) { - float zp_scale_val = a_scale[i] * static_cast(azp[i]); - if constexpr (!PerChannel) { - zp_scale_val *= *b_scale; - } - token_zp_scale_vec = cvt_vec_t(zp_scale_val); - } - - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t 
bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} -#else -template -void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int32_t* azp, - const int num_tokens, - const int hidden_size) { - TORCH_CHECK(false, - "static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 " - "support.") -} - -template -void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, int32_t* azp, - const int num_tokens, - const int hidden_size) { - TORCH_CHECK(false, - "dynamic_scaled_int8_quant_impl requires " - "AVX512/powerpc64/AArch64 support.") -} - -template -void static_quant_epilogue(const float* input, scalar_t* output, - const float a_scale, const float* b_scale, - const int32_t* azp_with_adj, const int num_tokens, - const int hidden_size) { - TORCH_CHECK( - false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.") -} - -template -void dynamic_quant_epilogue(const float* input, scalar_t* output, - const float* a_scale, const float* b_scale, - const int32_t* azp, const int32_t* azp_with_adj, - const scalar_t* bias, const int num_tokens, - const int hidden_size) { - 
TORCH_CHECK( - false, - "dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.") -} -#endif -} // namespace - -void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major - const torch::Tensor& a, // [M, IC], row-major - const torch::Tensor& b, // [IC, OC], column-major - const torch::Tensor& a_scales, // [1] or [M] - const torch::Tensor& b_scales, // [1] or [OC] - const std::optional& bias // [OC] -) { - CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) - // Checks for conformality - TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, - "int8_scaled_mm only supports INT8 inputs.") - TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); - TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); - TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); - TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); - - // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && - b.stride(1) % 16 == 0); // 16 Byte Alignment - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - - if (bias) { - TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && - bias->dim() == 1); - } - - VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] { - if (a_scales.numel() != 1) { - // per-token - // Note: oneDNN doesn't support per-token activation quantization - // Ideally we want to fuse the GEMM and the scale procedure with oneDNN - // JIT, the intermediate data is cached in registers or L1. But for now - // the oneDNN GEMM code generation only supports two quantization - // patterns: per-tensor or per-output-channel of weight. - // So we have to apply the per-token scale with a 'epilogue'. 
In C=s_a * - // s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN - // GEMM, then the per-token scale (and bias) is applied with the epilogue - // C=s_a * C_inter + bias. - torch::Tensor tmp_fp32_out = - torch::empty_like(c, ::at::ScalarType::Float); - // Compute C_inter=s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); - if (bias.has_value()) { - // Compute C=s_a * C_inter + bias - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, - bias->data_ptr(), c.size(0), c.size(1)); - } else { - // Compute C=s_a * C_inter - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, - c.size(0), c.size(1)); - } - } else { - // per-tensor - if (bias.has_value()) { - // Compute C=s_a * s_b * (A@B) + bias - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), c.data_ptr(), - bias->data_ptr(), a.size(0), b.size(1), a.size(1), - a_scales.data_ptr(), b_scales.data_ptr(), - a_scales.numel(), b_scales.numel()); - } else { - // Compute C=s_a * s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), c.data_ptr(), - nullptr, a.size(0), b.size(1), a.size(1), - a_scales.data_ptr(), b_scales.data_ptr(), - a_scales.numel(), b_scales.numel()); - } - } - }); -} - -void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major - const torch::Tensor& a, // [M, IC], row-major - const torch::Tensor& b, // [IC, OC], column-major - const torch::Tensor& a_scales, // [1] or [M] - const torch::Tensor& b_scales, // [1] or [OC] - const torch::Tensor& azp_adj, // [OC] - const std::optional& azp, // [1] or [M] - const std::optional& bias // [OC] -) { - CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp) - // Checks for conformality - TORCH_CHECK(a.dtype() == torch::kInt8 && 
b.dtype() == torch::kInt8, - "int8_scaled_mm_azp only supports INT8 inputs.") - TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); - TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); - TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); - TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); - - // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && - b.stride(1) % 16 == 0); // 16 Byte Alignment - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - - if (bias) { - TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous()); - } - if (azp) { - TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous()); - } - TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous()); - - // azp & bias types - TORCH_CHECK(azp_adj.dtype() == torch::kInt32); - TORCH_CHECK(!azp || azp->dtype() == torch::kInt32); - TORCH_CHECK(!bias || bias->dtype() == c.dtype(), - "currently bias dtype must match output dtype ", c.dtype()); - - VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] { - torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); - if (a_scales.numel() != 1) { - // per-token - // Note: oneDNN doesn't support per-token activation quantization - // Compute C_inter=s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); - if (bias.has_value()) { - // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias - if (b_scales.numel() != 1) { - // Per-Channel - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), - bias->data_ptr(), c.size(0), c.size(1)); - } 
else { - // Per-Tensor - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), - bias->data_ptr(), c.size(0), c.size(1)); - } - } else { - // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj - if (b_scales.numel() != 1) { - // Per-Channel - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), nullptr, - c.size(0), c.size(1)); - } else { - // Per-Tensor - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), nullptr, - c.size(0), c.size(1)); - } - } - } else { - // per-tensor - if (bias.has_value()) { - // Compute C_inter=s_a * s_b * (A@B) + bias - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), bias->data_ptr(), - a.size(0), b.size(1), a.size(1), a_scales.data_ptr(), - b_scales.data_ptr(), a_scales.numel(), b_scales.numel()); - } else { - // Compute C_inter=s_a * s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), a_scales.data_ptr(), b_scales.data_ptr(), - a_scales.numel(), b_scales.numel()); - } - - // Compute C=C_inter - s_a * s_b * azp_adj - if (b_scales.numel() != 1) { - // Per-Channel - static_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - *a_scales.data_ptr(), b_scales.data_ptr(), - azp_adj.data_ptr(), a.size(0), b.size(1)); - } else { - // Per-Tensor - static_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - *a_scales.data_ptr(), b_scales.data_ptr(), - azp_adj.data_ptr(), a.size(0), b.size(1)); - } - } - }); -} - -// static-per-tensor quantization. 
-void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] - const torch::Tensor& input, // [..., hidden_size] - const torch::Tensor& scale, - std::optional const& azp) { - CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(scale.numel() == 1); - TORCH_CHECK(!azp.has_value() || azp->numel() == 1); - - const int hidden_size = input.size(-1); - const int num_tokens = input.numel() / hidden_size; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "static_scaled_int8_quant_impl", [&] { - if (azp.has_value()) { - static_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), azp->data_ptr(), num_tokens, - hidden_size); - } else { - static_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), nullptr, num_tokens, hidden_size); - } - }); -} - -// dynamic-per-token quantization. -void dynamic_scaled_int8_quant( - torch::Tensor& out, // [..., hidden_size] - const torch::Tensor& input, // [..., hidden_size] - torch::Tensor& scale, // [..., 1] - std::optional const& azp) { - CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - - int const hidden_size = input.size(-1); - int const num_tokens = input.numel() / hidden_size; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { - if (azp.has_value()) { - dynamic_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), azp->data_ptr(), num_tokens, - hidden_size); - } else { - dynamic_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), nullptr, num_tokens, hidden_size); - } - }); -} - -#if defined(__powerpc64__) -void int8_scaled_mm_ppc64le(torch::Tensor& c, // [M, OC], row-major - const torch::Tensor& a, // [M, IC], row-major - const torch::Tensor& b, // [IC, OC], column-major - const torch::Tensor& a_scales, - const 
torch::Tensor& b_scales, - const std::optional& bias // [OC] -) { - CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) - // Checks for conformality - TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, - "int8_scaled_mm_ppc64le only supports INT8 inputs."); - TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); - TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); - // We dont need this - TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); - TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); - - // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && - b.stride(1) % 16 == 0); // 16 Byte Alignment - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - - if (bias) { - TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && - bias->dim() == 1); - } - VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] { - torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); - // Compute C_inter=s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); - if (bias.has_value()) { - // Compute C=s_a * C_inter + bias - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, - bias->data_ptr(), c.size(0), c.size(1)); - } else { - // Compute C=s_a * C_inter - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, - c.size(0), c.size(1)); - } - }); -} - -#endif diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index b20a054648428..c9f426bdf618a 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp 
@@ -6,25 +6,20 @@ std::string init_cpu_threads_env(const std::string& cpu_ids); -void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a, - const torch::Tensor& b, const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const std::optional& bias); +void release_dnnl_matmul_handler(int64_t handler); -void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a, - const torch::Tensor& b, const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const torch::Tensor& azp_adj, - const std::optional& azp, - const std::optional& bias); +int64_t create_onednn_scaled_mm_handler(const torch::Tensor& b, + const torch::Tensor& b_scales, + at::ScalarType output_type, + bool dynamic_act_quant, bool use_azp, + int64_t primitive_cache_size); -#if defined(__powerpc64__) -void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a, - const torch::Tensor& b, - const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const std::optional& bias); -#endif +void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a, + const torch::Tensor& a_scales, + const std::optional& azp, + const std::optional& azp_adj, + const std::optional& bias, + int64_t handler); void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query, torch::Tensor& kv_cache, double scale, @@ -151,8 +146,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); // Quantization -#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) +#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \ + defined(__powerpc64__) at::Tag stride_tag = at::Tag::needs_fixed_stride_order; + // Helper function to release oneDNN handlers + ops.def("release_dnnl_matmul_handler(int handler) -> ()", + &release_dnnl_matmul_handler); + + // Create oneDNN W8A8 handler + ops.def( + "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType " + "output_type, bool dynamic_act_quant, bool 
use_azp, int " + "primitive_cache_size) -> int", + &create_onednn_scaled_mm_handler); + + // oneDNN scaled_mm for W8A8 with static per-tensor activation quantization + ops.def( + "onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, " + "Tensor? azp_adj, Tensor? bias, int handler) -> ()"); + ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm); // Compute int8 quantized tensor for given scaling factor. ops.def( @@ -168,50 +180,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("dynamic_scaled_int8_quant", torch::kCPU, &dynamic_scaled_int8_quant); - // W8A8 GEMM, supporting symmetric per-tensor or per-row/column - // quantization. - ops.def( - "cutlass_scaled_mm(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor? bias) -> ()", - {stride_tag}); - ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); - // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column - // quantization. - ops.def( - "cutlass_scaled_mm_azp(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor azp_adj," - " Tensor? azp, Tensor? bias) -> ()", - {stride_tag}); - ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); -#elif defined(__powerpc64__) - // Compute int8 quantized tensor for given scaling factor. - ops.def( - "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," - "Tensor? azp) -> ()"); - ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); - - // Compute int8 quantized tensor and scaling factor - ops.def( - "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " - "Tensor!? azp) -> ()"); - ops.impl("dynamic_scaled_int8_quant", torch::kCPU, - &dynamic_scaled_int8_quant); - // W8A8 GEMM, supporting symmetric quantization. - ops.def( - "cutlass_scaled_mm(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor? 
bias) -> ()"); - ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le); - // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column - // quantization. - ops.def( - "cutlass_scaled_mm_azp(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor azp_adj," - " Tensor? azp, Tensor? bias) -> ()"); - ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); #endif // SHM CCL diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu new file mode 100644 index 0000000000000..78f7b3cc1aa25 --- /dev/null +++ b/csrc/moe/grouped_topk_kernels.cu @@ -0,0 +1,757 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/v0.21.0/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu + * Copyright (c) 2025, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include +#include +#include +#include +namespace cg = cooperative_groups; + +namespace vllm { +namespace moe { + +constexpr unsigned FULL_WARP_MASK = 0xffffffff; +constexpr int32_t WARP_SIZE = 32; +constexpr int32_t BLOCK_SIZE = 512; +constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE; + +namespace warp_topk { + +template +__host__ __device__ constexpr T round_up_to_multiple_of(T len) { + if (len == 0) { + return 0; + } + return ((len - 1) / size + 1) * size; +} + +template +constexpr __host__ __device__ bool isPowerOf2(T v) { + return (v && !(v & (v - 1))); +} + +template +__forceinline__ __device__ bool is_better_than(T val, T baseline) { + return (val > baseline && greater) || (val < baseline && !greater); +} + +template +__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index, + idxT baseline_index) { + bool res = (val > baseline && greater) || (val < baseline && !greater); + if (val == baseline) { + res = (index < baseline_index && greater) || + (index < baseline_index && !greater); + } + return res; +} + +template +int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) { + int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k; + int64_t n = std::max(num_of_warp / 2 * k, num_of_warp * WARP_SIZE); + return max(cache_topk, + round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT)); +} + +template +struct BitonicMerge { + // input should be a bitonic sequence, and sort it to be a monotonic sequence + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + + constexpr int stride = arr_len / 2; + for (int i = 0; i < stride; ++i) { + int const other_i = i + stride; + T& val = val_arr[i]; + T& other_val = val_arr[other_i]; + bool is_better; + if constexpr (is_stable) { + is_better = is_better_than(val, other_val, idx_arr[i], + 
idx_arr[other_i]); + } else { + is_better = is_better_than(val, other_val); + } + + if (is_better) { + T tmp = val; + val = other_val; + other_val = tmp; + + idxT tmp2 = idx_arr[i]; + idx_arr[i] = idx_arr[other_i]; + idx_arr[other_i] = tmp2; + } + } + + BitonicMerge::merge( + val_arr, idx_arr); + BitonicMerge::merge( + val_arr + arr_len / 2, idx_arr + arr_len / 2); + } +}; + +template +struct BitonicSort { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + + BitonicSort::sort(val_arr, idx_arr); + BitonicSort::sort( + val_arr + arr_len / 2, idx_arr + arr_len / 2); + BitonicMerge::merge( + val_arr, idx_arr); + } +}; + +template +struct BitonicSort<32, ascending, T, idxT, is_stable> { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + + // ascending doesn't matter before merging since all we need is a bitonic + // sequence + for (int stage = 0; stage < 4; ++stage) { + for (int stride = (1 << stage); stride > 0; stride /= 2) { + bool reverse = (lane >> stage) & 2; + bool is_second = lane & stride; + + T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride); + idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride); + + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) != + (reverse != is_second); + } else { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) != + (reverse != is_second); + } + } else { + is_better = (*val_arr != other && + (*val_arr > other) != (reverse != is_second)); + } + if (is_better) { + *val_arr = other; + *idx_arr = other_idx; + } + } + } + + BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr, + idx_arr); + } +}; + +template 
+struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> { + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) { + bool is_second = lane & stride; + T& val = *val_arr; + T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride); + idxT& idx = *idx_arr; + idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride); + + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) == + (reverse != is_second); // for min + } else { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) == + (reverse != is_second); // for max + } + } else { + is_better = + (val != other && ((val > other) == (ascending != is_second))); + } + + if (is_better) { + val = other; + idx = other_idx; + } + } + } +}; + +template +class WarpSort { + public: + __device__ WarpSort(idxT k, T dummy) + : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) { + static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity)); + + for (int i = 0; i < max_arr_len_; ++i) { + val_arr_[i] = dummy_; + idx_arr_[i] = 0; + } + } + + // load and merge k sorted values + __device__ void load_sorted(T const* __restrict__ in, + idxT const* __restrict__ in_idx, idxT start) { + idxT idx = start + WARP_SIZE - 1 - lane_; + for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) { + if (idx < start + k_) { + T t = in[idx]; + bool is_better; + if constexpr (is_stable) { + is_better = + is_better_than(t, val_arr_[i], in_idx[idx], idx_arr_[i]); + } else { + is_better = is_better_than(t, val_arr_[i]); + } + if (is_better) { + val_arr_[i] = t; + idx_arr_[i] = in_idx[idx]; + } + } + } + + BitonicMerge::merge( + val_arr_, idx_arr_); + } + + __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const { + for (int i = 0; i < 
max_arr_len_; ++i) { + idxT out_i = i * WARP_SIZE + lane_; + if (out_i < k_) { + out[out_i] = val_arr_[i]; + out_idx[out_i] = idx_arr_[i]; + } + } + } + + __device__ void dumpIdx(idxT* __restrict__ out_idx) const { + for (int i = 0; i < max_arr_len_; ++i) { + idxT out_i = i * WARP_SIZE + lane_; + if (out_i < k_) { + out_idx[out_i] = idx_arr_[i]; + } + } + } + + protected: + static constexpr int max_arr_len_ = capacity / WARP_SIZE; + + T val_arr_[max_arr_len_]; + idxT idx_arr_[max_arr_len_]; + + int const lane_; + idxT const k_; + T const dummy_; + +}; // end class WarpSort + +template +class WarpSelect : public WarpSort { + public: + __device__ WarpSelect(idxT k, T dummy) + : WarpSort(k, dummy), + k_th_(dummy), + k_th_lane_((k - 1) % WARP_SIZE) { + extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[]; + + int const num_of_warp = blockDim.x / WARP_SIZE; + int const warp_id = threadIdx.x / WARP_SIZE; + val_smem_ = reinterpret_cast(smem_buf); + val_smem_ += warp_id * WARP_SIZE; + idx_smem_ = reinterpret_cast( + smem_buf + + round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE)); + idx_smem_ += warp_id * WARP_SIZE; + } + + __device__ void add(T const* in, idxT start, idxT end) { + idxT const end_for_fullwarp = + round_up_to_multiple_of(end - start) + start; + for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) { + T val = (i < end) ? 
in[i] : dummy_; + add(val, i); + } + } + + __device__ void add(T val, idxT idx) { + bool do_add; + if constexpr (is_stable) { + do_add = is_better_than(val, k_th_, idx, k_th_idx_); + } else { + do_add = is_better_than(val, k_th_); + } + + uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add); + if (mask == 0) { + return; + } + + int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1)); + if (do_add && pos < WARP_SIZE) { + val_smem_[pos] = val; + idx_smem_[pos] = idx; + do_add = false; + } + smem_buf_len_ += __popc(mask); + if (smem_buf_len_ >= WARP_SIZE) { + __syncwarp(); + merge_buf_(val_smem_[lane_], idx_smem_[lane_]); + smem_buf_len_ -= WARP_SIZE; + } + if (do_add) { + pos -= WARP_SIZE; + val_smem_[pos] = val; + idx_smem_[pos] = idx; + } + __syncwarp(); + } + + __device__ void done() { + if (smem_buf_len_) { + T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_; + idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0; + merge_buf_(val, idx); + } + + // after done(), smem is used for merging results among warps + __syncthreads(); + } + + private: + __device__ void set_k_th_() { + k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_); + if constexpr (is_stable) { + k_th_idx_ = + __shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_); + } + } + + __device__ void merge_buf_(T val, idxT idx) { + BitonicSort::sort(&val, &idx); + + T& old = val_arr_[max_arr_len_ - 1]; + + bool is_better; + if constexpr (is_stable) { + is_better = + is_better_than(val, old, idx, idx_arr_[max_arr_len_ - 1]); + } else { + is_better = is_better_than(val, old); + } + + if (is_better) { + old = val; + idx_arr_[max_arr_len_ - 1] = idx; + } + + BitonicMerge::merge( + val_arr_, idx_arr_); + + set_k_th_(); + } + + using WarpSort::max_arr_len_; + using WarpSort::val_arr_; + using WarpSort::idx_arr_; + using WarpSort::lane_; + using WarpSort::k_; + using WarpSort::dummy_; + + T* val_smem_; + idxT* idx_smem_; + int smem_buf_len_ = 0; + + T k_th_; 
+ idxT k_th_idx_; + int const k_th_lane_; +}; // end class WarpSelect +} // namespace warp_topk + +template +__device__ inline T_OUT cuda_cast(T_IN val) { + return val; +} + +template <> +__device__ inline float cuda_cast(__nv_bfloat16 val) { + return __bfloat162float(val); +} + +template +__device__ void topk_with_k2(T* output, T const* input, + cg::thread_block_tile<32> const& tile, + int32_t const lane_id, + int const num_experts_per_group) { + // Get the top2 per thread + T largest = -INFINITY; + T second_largest = -INFINITY; + + if (num_experts_per_group > WARP_SIZE) { + for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { + T value = input[i]; + if (value > largest) { + second_largest = largest; + largest = value; + } else if (value > second_largest) { + second_largest = value; + } + } + } else { + for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { + largest = input[i]; + } + } + + __syncwarp(); // Ensure all threads have valid data before reduction + // Get the top2 warpwise + T max1 = cg::reduce(tile, largest, cg::greater()); + + T max2 = max1; + bool equal_to_max1 = (max1 == largest); + + int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1)); + + if (count_max1 == 1) { + largest = (largest == max1) ? 
second_largest : largest; + max2 = cg::reduce(tile, largest, cg::greater()); + } + + if (lane_id == 0) { + *output = max1 + max2; + } +} + +template +__global__ void topk_with_k2_kernel(T* output, T* input, + int64_t const num_tokens, + int64_t const num_cases, + int64_t const n_group, + int64_t const num_experts_per_group) { + int32_t warp_id = threadIdx.x / WARP_SIZE; + int32_t lane_id = threadIdx.x % WARP_SIZE; + + int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; + if (case_id < num_cases) { + input += case_id * num_experts_per_group; + output += case_id; + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + topk_with_k2(output, input, tile, lane_id, num_experts_per_group); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +template +__global__ void group_idx_and_topk_idx_kernel( + T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices, + T* scores_with_bias, int64_t const num_tokens, int64_t const n_group, + int64_t const topk_group, int64_t const topk, int64_t const num_experts, + int64_t const num_experts_per_group, bool renormalize, + double routed_scaling_factor) { + int32_t warp_id = threadIdx.x / WARP_SIZE; + int32_t lane_id = threadIdx.x % WARP_SIZE; + int32_t case_id = + blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token + scores_with_bias += case_id * num_experts; + scores += case_id * num_experts; + group_scores += case_id * n_group; + topk_values += case_id * topk; + topk_indices += case_id * topk; + + int32_t align_num_experts_per_group = + warp_topk::round_up_to_multiple_of(num_experts_per_group); + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); + + extern __shared__ char smem_buf[]; // NOTE: 
reuse the shared memory here to + // store the target topk idx + int32_t* s_topk_idx = reinterpret_cast(smem_buf); + T* s_topk_value = + reinterpret_cast(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) + + warp_id * topk; + s_topk_idx += warp_id * topk; + + T value = cuda::std::numeric_limits::min(); + T topk_group_value = cuda::std::numeric_limits::min(); + int32_t num_equalto_topkth_group; + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before + // acqbulk because it's ptr arithmetic +#endif + + if (case_id < num_tokens) { + // calculate group_idx + int32_t target_num_min = WARP_SIZE - n_group + topk_group; + if (lane_id < n_group && + (isfinite(cuda_cast( + group_scores[lane_id])))) // The check is necessary to avoid + // abnormal input + { + value = group_scores[lane_id]; + } + + int count_equal_to_top_value = WARP_SIZE - n_group; + int pre_count_equal_to_top_value = 0; + // Use loop to find the largset top_group + while (count_equal_to_top_value < target_num_min) { + __syncwarp(); // Ensure all threads have valid data before reduction + topk_group_value = cg::reduce(tile, value, cg::greater()); + if (value == topk_group_value) { + value = cuda::std::numeric_limits::min(); + } + pre_count_equal_to_top_value = count_equal_to_top_value; + count_equal_to_top_value = __popc(__ballot_sync( + FULL_WARP_MASK, (value == cuda::std::numeric_limits::min()))); + } + num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; + } + __syncthreads(); + + warp_topk::WarpSelect + queue((int32_t)topk, -INFINITY); + + int count_equalto_topkth_group = 0; + bool if_proceed_next_topk = + (topk_group_value != cuda::std::numeric_limits::min()); + if (case_id < num_tokens && if_proceed_next_topk) { + for (int i_group = 0; i_group < n_group; i_group++) { + if ((group_scores[i_group] > topk_group_value) || + ((group_scores[i_group] == topk_group_value) && + (count_equalto_topkth_group < 
num_equalto_topkth_group))) { + int32_t offset = i_group * num_experts_per_group; + for (int32_t i = lane_id; i < align_num_experts_per_group; + i += WARP_SIZE) { + T candidates = + (i < num_experts_per_group) && isfinite(cuda_cast( + scores_with_bias[offset + i])) + ? scores_with_bias[offset + i] + : cuda::std::numeric_limits::min(); + queue.add(candidates, offset + i); + } + if (group_scores[i_group] == topk_group_value) { + count_equalto_topkth_group++; + } + } + } + queue.done(); + __syncwarp(); + // Get the topk_idx + queue.dumpIdx(s_topk_idx); + __syncwarp(); + } + + // Load the valid score value + // Calculate the summation + float topk_sum = 1e-20; + if (case_id < num_tokens && if_proceed_next_topk) { + for (int i = lane_id; + i < warp_topk::round_up_to_multiple_of(topk); + i += WARP_SIZE) { + T value = + i < topk + ? scores[s_topk_idx[i]] + : cuda_cast(0.0f); // Load the valid value of expert + if (i < topk) { + s_topk_value[i] = value; + } + topk_sum += reduce(tile, cuda_cast(value), cg::plus()); + } + } + + __syncthreads(); + + if (case_id < num_tokens) { + if (if_proceed_next_topk) { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + float value; + if (renormalize) { + value = cuda_cast(s_topk_value[i]) / topk_sum * + routed_scaling_factor; + } else { + value = cuda_cast(s_topk_value[i]) * routed_scaling_factor; + } + topk_indices[i] = s_topk_idx[i]; + topk_values[i] = cuda_cast(value); + } + } else { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + topk_indices[i] = i; + topk_values[i] = cuda_cast(1.0f / topk); + } + } + // Note: when if_proceed_next_topk==false, choose the first 8 experts as the + // default result. 
+ } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +template +void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values, + IdxT* topk_indices, T* scores_with_bias, + int64_t const num_tokens, int64_t const num_experts, + int64_t const n_group, int64_t const topk_group, + int64_t const topk, bool const renormalize, + double const routed_scaling_factor, bool enable_pdl = false, + cudaStream_t const stream = 0) { + int64_t num_cases = num_tokens * n_group; + int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; + auto* kernel_instance1 = &topk_with_k2_kernel; + cudaLaunchConfig_t config; + config.gridDim = topk_with_k2_num_blocks; + config.blockDim = BLOCK_SIZE; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias, + num_tokens, num_cases, n_group, num_experts / n_group); + + int64_t topk_with_k_group_num_blocks = + (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1; + size_t dynamic_smem_in_bytes = + warp_topk::calc_smem_size_for_block_wide(NUM_WARPS_PER_BLOCK, + topk); + auto* kernel_instance2 = &group_idx_and_topk_idx_kernel; + config.gridDim = topk_with_k_group_num_blocks; + config.blockDim = BLOCK_SIZE; + config.dynamicSmemBytes = dynamic_smem_in_bytes; + config.stream = stream; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, scores_with_bias, num_tokens, + n_group, topk_group, topk, num_experts, + num_experts / n_group, renormalize, 
routed_scaling_factor); +} + +#define INSTANTIATE_NOAUX_TC(T, IdxT) \ + template void invokeNoAuxTc( \ + T * scores, T * group_scores, T * topk_values, IdxT * topk_indices, \ + T * scores_with_bias, int64_t const num_tokens, \ + int64_t const num_experts, int64_t const n_group, \ + int64_t const topk_group, int64_t const topk, bool const renormalize, \ + double const routed_scaling_factor, bool enable_pdl, \ + cudaStream_t const stream); + +INSTANTIATE_NOAUX_TC(float, int32_t); +INSTANTIATE_NOAUX_TC(half, int32_t); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t); +} // end namespace moe +} // namespace vllm + +std::tuple grouped_topk( + torch::Tensor const& scores, torch::Tensor const& scores_with_bias, + int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize, + double routed_scaling_factor) { + auto data_type = scores_with_bias.scalar_type(); + auto input_size = scores_with_bias.sizes(); + int64_t num_tokens = input_size[0]; + int64_t num_experts = input_size[1]; + TORCH_CHECK(input_size.size() == 2, "scores_with_bias must be a 2D Tensor"); + TORCH_CHECK(num_experts % n_group == 0, + "num_experts should be divisible by n_group"); + TORCH_CHECK(n_group <= 32, + "n_group should be smaller than or equal to 32 for now"); + TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now"); + + torch::Tensor group_scores = torch::empty( + {num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA)); + torch::Tensor topk_values = torch::empty( + {num_tokens, topk}, torch::dtype(data_type).device(torch::kCUDA)); + torch::Tensor topk_indices = torch::empty( + {num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + + auto stream = c10::cuda::getCurrentCUDAStream(scores_with_bias.get_device()); + + switch (data_type) { + case torch::kFloat16: + // Handle Float16 + vllm::moe::invokeNoAuxTc( + reinterpret_cast(scores.mutable_data_ptr()), + reinterpret_cast(group_scores.mutable_data_ptr()), + 
reinterpret_cast(topk_values.mutable_data_ptr()), + reinterpret_cast(topk_indices.mutable_data_ptr()), + reinterpret_cast(scores_with_bias.data_ptr()), num_tokens, + num_experts, n_group, topk_group, topk, renormalize, + routed_scaling_factor, false, stream); + break; + case torch::kFloat32: + // Handle Float32 + vllm::moe::invokeNoAuxTc( + reinterpret_cast(scores.mutable_data_ptr()), + reinterpret_cast(group_scores.mutable_data_ptr()), + reinterpret_cast(topk_values.mutable_data_ptr()), + reinterpret_cast(topk_indices.mutable_data_ptr()), + reinterpret_cast(scores_with_bias.data_ptr()), num_tokens, + num_experts, n_group, topk_group, topk, renormalize, + routed_scaling_factor, false, stream); + break; + case torch::kBFloat16: + // Handle BFloat16 + vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>( + reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16*>(topk_values.mutable_data_ptr()), + reinterpret_cast(topk_indices.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16*>(scores_with_bias.data_ptr()), + num_tokens, num_experts, n_group, topk_group, topk, renormalize, + routed_scaling_factor, false, stream); + break; + default: + // Handle other data types + throw std::invalid_argument( + "Invalid dtype, only supports float16, float32, and bfloat16"); + break; + } + return {topk_values, topk_indices}; +} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 661730c96867e..92fc280b362b9 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -22,6 +22,11 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, int64_t BLOCK_SIZE_K, int64_t bit); + +std::tuple grouped_topk( + torch::Tensor const& scores, torch::Tensor const& scores_with_bias, + int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize, + double 
routed_scaling_factor); #endif bool moe_permute_unpermute_supported(); diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index 2922352a3f7cc..ca0c873f49d9f 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -45,8 +45,6 @@ void moe_permute( auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess auto permuted_experts_id = torch::empty_like(topk_ids); auto sorted_row_idx = torch::empty_like(inv_permuted_idx); - auto align_expert_first_token_offset = - torch::zeros_like(expert_first_token_offset); CubKeyValueSorter sorter{}; int64_t* valid_num_ptr = nullptr; @@ -85,12 +83,14 @@ void moe_permute( }); // get m_indices and update expert_first_token_offset with align block - getMIndices(get_ptr(expert_first_token_offset), - get_ptr(align_expert_first_token_offset), - get_ptr(m_indices), n_local_expert, align_block_size_value, - stream); + // this is only required for DeepGemm and not required for CUTLASS group gemm if (align_block_size.has_value()) { - // update align_expert_first_token_offset + auto align_expert_first_token_offset = + torch::zeros_like(expert_first_token_offset); + getMIndices(get_ptr(expert_first_token_offset), + get_ptr(align_expert_first_token_offset), + get_ptr(m_indices), n_local_expert, align_block_size_value, + stream); expert_first_token_offset.copy_(align_expert_first_token_offset); } } @@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, torch::Tensor& expert_first_token_offset, torch::Tensor& src_row_id2dst_row_id_map, torch::Tensor& m_indices) { - TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); + TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0"); } -void moe_unpermute(const torch::Tensor& input, - const torch::Tensor& topk_weights, torch::Tensor& topk_ids, - const torch::Tensor& token_expert_indices, - const std::optional& expert_map, - int64_t n_expert, 
int64_t n_local_expert, int64_t topk, - const std::optional& align_block_size, - torch::Tensor& permuted_input, - torch::Tensor& expert_first_token_offset, - torch::Tensor& src_row_id2dst_row_id_map, - torch::Tensor& m_indices) { +void moe_unpermute( + const torch::Tensor& permuted_hidden_states, + const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx, + const std::optional& expert_first_token_offset, int64_t topk, + torch::Tensor& hidden_states) { TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); } @@ -224,4 +219,4 @@ bool moe_permute_unpermute_supported() { TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("moe_permute", &moe_permute); m.impl("moe_unpermute", &moe_unpermute); -} +} \ No newline at end of file diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 7e49f68f62438..8f33d6cd666fa 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -78,6 +78,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "output_tensor) -> ()"); m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows); + // Apply grouped topk routing to select experts. 
+ m.def( + "grouped_topk(Tensor scores, Tensor scores_with_bias, int n_group, int " + "topk_group, int topk, bool renormalize, float " + "routed_scaling_factor) -> (Tensor, Tensor)"); + m.impl("grouped_topk", torch::kCUDA, &grouped_topk); #endif } diff --git a/csrc/ops.h b/csrc/ops.h index 3e29f0a973dd6..86fe848e2fd5a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -138,6 +138,8 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input, double threshold); +void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input, + double alpha = 1.702, double limit = 7.0); void gelu_new(torch::Tensor& out, torch::Tensor& input); @@ -154,15 +156,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor); #ifndef USE_ROCM -torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const std::vector& codebook_partition_sizes, - const std::optional& bias); - -torch::Tensor aqlm_dequant( - const torch::Tensor& codes, const torch::Tensor& codebooks, - const std::vector& codebook_partition_sizes); torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, @@ -236,6 +229,11 @@ void get_cutlass_moe_mm_data( const int64_t num_experts, const int64_t n, const int64_t k, const std::optional& blockscale_offsets); +void get_cutlass_moe_mm_problem_sizes( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets); + void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu 
deleted file mode 100644 index 79cd2c610b3c2..0000000000000 --- a/csrc/quantization/aqlm/gemm_kernels.cu +++ /dev/null @@ -1,597 +0,0 @@ -/* - * Modified by Neural Magic - * Adapted from https://github.com/Vahe1994/AQLM - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace vllm { -namespace aqlm { - -__global__ void Code1x16MatVec( - const int4* __restrict__ A, const int4* __restrict__ B, - int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m, - const int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each - // codebook, at most 3 long. - const int codebook_stride // as int4. -) { - int a_gl_stride = prob_k / 8 / 8; - int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); - bool pred = a_gl_rd < prob_m; - - if (pred) { - // advance to the correct codebook, this easy because we only multiply one - // column of the codebook. 
- auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) { - codebook += codebook_stride; - ++codebook_size; - } - } - - int b_gl_rd = 0; - int c_gl_wr = a_gl_rd; - a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; - int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; - - __shared__ int4 sh_b[32 * 9]; - float res = 0; - - int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32); - while (iters--) { - // We pad shared memory to avoid bank conflicts during reads - __syncthreads(); - for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { - if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; - } - __syncthreads(); - b_gl_rd += 32 * 8; - - int b_sh_rd = 9 * (threadIdx.x % 32); - if (pred && a_gl_rd < a_gl_end) { - const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); -#pragma unroll - for (int i = 0; i < 8; i++) { - uint32_t dec[4]; - // We bypass the L1 cache to avoid massive amounts of memory streaming - // that doesn't actually help us; this brings > 2x speedup. - asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" - : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) - : "l"((void*)&codebook[enc[i]])); - half2* a = reinterpret_cast(&dec); - half2* b = reinterpret_cast(&sh_b[b_sh_rd]); - half2 res2 = {}; -#pragma unroll - for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2); - res += __half2float(res2.x) + __half2float(res2.y); - b_sh_rd++; - } - a_gl_rd += 32; - } - } - - if (pred) { -#pragma unroll - for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); - if (threadIdx.x % 32 == 0) - reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); - } -} - -__global__ void Code2x8MatVec( - const int4* __restrict__ A, const int4* __restrict__ B, - int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each - // codebook, at most 3 long. - const int codebook_stride // as int4. 
- -) { - int a_gl_stride = prob_k / 8 / 8; - int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); - bool pred = a_gl_rd < prob_m; - - if (pred) { - // advance to the correct codebook, this easy because we only multiply one - // column of the codebook. - auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) { - codebook += codebook_stride; - ++codebook_size; - } - } - - int b_gl_rd = 0; - int c_gl_wr = a_gl_rd; - a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; - int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; - int lane = threadIdx.x % 8; - - extern __shared__ int4 sh[]; - int4* sh_b = sh; - int4* sh_code = sh_b + 32 * 9; - int4* sh_code0 = sh_code; - int4* sh_code1 = sh_code + 256 * 8; - - for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { - int4 dec = codebook[i]; -#pragma unroll - for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; - } - __syncthreads(); - - float res = 0; - - int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32); - while (iters--) { - // We pad shared memory to avoid bank conflicts during reads - __syncthreads(); - for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { - if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; - } - __syncthreads(); - b_gl_rd += 32 * 8; - - int b_sh_rd = 9 * (threadIdx.x % 32); - if (pred && a_gl_rd < a_gl_end) { - const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); -#pragma unroll - for (int i = 0; i < 8; i++) { - half2* a0 = - reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); - half2* a1 = - reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); - half2* b = reinterpret_cast(&sh_b[b_sh_rd]); - half2 res2 = {}; -#pragma unroll - for (int j = 0; j < 4; j++) - res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2); - res += __half2float(res2.x) + __half2float(res2.y); - b_sh_rd++; - } - a_gl_rd += 32; - } - } - - if (pred) { -#pragma unroll - for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); - if 
(threadIdx.x % 32 == 0) - reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); - } -} - -__global__ void Code1x16Dequant( - const int4* __restrict__ A, int4* __restrict__ C, - const int4* __restrict__ codebook, int prob_m, int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each - // codebook, at most 3 long, sums to m. - const int codebook_stride // as int4 -) { - int a_gl_stride = prob_k / 8 / 8; - int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); - bool pred = a_gl_rd < prob_m; - - if (pred) { - // advance to the correct codebook, this easy because we only multiply one - // column of the codebook. - auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) { - codebook += codebook_stride; - ++codebook_size; - } - } - - a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; - int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; - - int c_gl_stride = prob_k / 8; - int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); - c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8; - - int iters = (prob_k / 8 - 1) / (8 * 32) + 1; - while (iters--) { - if (pred && a_gl_rd < a_gl_end) { - const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); -#pragma unroll - for (int i = 0; i < 8; i++) { - int4 chunk; - auto dec = reinterpret_cast(&chunk); - // We bypass the L1 cache to avoid massive amounts of memory streaming - // that doesn't actually help us; this brings > 2x speedup. - asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" - : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) - : "l"((void*)&codebook[enc[i]])); - - C[a_gl_rd * 8 + i] = chunk; - } - } - a_gl_rd += 32; - } -} - -__global__ void Code2x8Dequant( - const int4* __restrict__ A, int4* __restrict__ C, - const int4* __restrict__ codebook, int prob_m, int prob_k, - const int4 - codebook_a_sizes, // cumulative sizes of A spanning each codebook, at - // most 3 long, corresponds to cols. 
- const int codebook_stride // as int4 -) { - int a_gl_stride = prob_k / 8 / 8; - int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); - bool pred = a_gl_rd < prob_m; - - if (pred) { - // advance to the correct codebook, this easy because we only multiply one - // column of the codebook. - auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) { - codebook += codebook_stride; - ++codebook_size; - } - } - - a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; - int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; - int lane = threadIdx.x % 8; - - int c_gl_stride = prob_k / 8; - int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); - c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8; - - extern __shared__ int4 sh[]; - int4* sh_code = sh; - int4* sh_code0 = sh_code; - int4* sh_code1 = sh_code + 256 * 8; - - for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { - int4 dec = codebook[i]; -#pragma unroll - for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; - } - __syncthreads(); - - int iters = (prob_k / 8 - 1) / (8 * 32) + 1; - while (iters--) { - if (pred && a_gl_rd < a_gl_end) { - const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); -#pragma unroll - for (int i = 0; i < 8; i++) { - int4 chunk; - half2* a0 = - reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); - half2* a1 = - reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); -#pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(&chunk)[j] = __hadd2(a0[j], a1[j]); - C[a_gl_rd * 8 + i] = chunk; - } - } - a_gl_rd += 32; - } -} - -inline int ceildiv(int a, int b) { return (a + b - 1) / b; } - -const int THREAD_M = 16; - -void code1x16_matvec_cuda(const void* __restrict__ A, - const void* __restrict__ B, void* __restrict__ C, - const void* __restrict__ codebook, int prob_m, - int prob_k, const int4 codebook_a_sizes, - const int codebook_stride) { - int sms; - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); - 
int waves = 0; - int thread_m; - do { - waves++; - thread_m = ceildiv(prob_m, waves * sms); - } while (thread_m > THREAD_M); - - int blocks = ceildiv(prob_m, thread_m); - int threads = 32 * thread_m; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - Code1x16MatVec<<>>( - (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, - prob_k, codebook_a_sizes, codebook_stride); -} - -void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B, - void* __restrict__ C, - const void* __restrict__ codebook, int prob_m, - int prob_k, const int4 codebook_a_sizes, - const int codebook_stride) { - int sms; - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); - int waves = 0; - int thread_m; - do { - waves++; - thread_m = ceildiv(prob_m, waves * sms); - } while (thread_m > THREAD_M); - - int blocks = ceildiv(prob_m, thread_m); - int threads = 32 * thread_m; - int shared = 16 * (2 * 256 * 8 + 32 * 9); - cudaFuncSetAttribute(Code2x8MatVec, - cudaFuncAttributeMaxDynamicSharedMemorySize, shared); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - Code2x8MatVec<<>>( - (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, - prob_k, codebook_a_sizes, codebook_stride); -} - -void code1x16_dequant_cuda( - const void* __restrict__ A, void* __restrict__ C, - const void* __restrict__ codebook, int prob_m, int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each - // codebook, at most 3 long. - const int codebook_stride // as int4. 
-) { - int sms; - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); - int waves = 0; - int thread_m; - do { - waves++; - thread_m = ceildiv(prob_m, waves * sms); - } while (thread_m > THREAD_M); - - int blocks = ceildiv(prob_m, thread_m); - int threads = 32 * thread_m; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - Code1x16Dequant<<>>( - (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, - codebook_a_sizes, // cumulative sizes of A spanning each codebook, at - // most 3 long. - codebook_stride // as int4. - ); -} - -// Dequantizes the code and codebook into weights. -void code2x8_dequant_cuda( - const void* __restrict__ A, void* __restrict__ C, - const void* __restrict__ codebook, int prob_m, int prob_k, - const int4 - codebook_a_sizes, // cumulative sizes of A spanning each codebook, at - // most 3 long, corresponds to cols. - const int codebook_stride // as int4 -) { - int sms; - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); - int waves = 0; - int thread_m; - do { - waves++; - thread_m = ceildiv(prob_m, waves * sms); - } while (thread_m > THREAD_M); - - int blocks = ceildiv(prob_m, thread_m); - int threads = 32 * thread_m; - int shared = 16 * (2 * 256 * 8 + 32 * 9); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - cudaFuncSetAttribute(Code2x8Dequant, - cudaFuncAttributeMaxDynamicSharedMemorySize, shared); - Code2x8Dequant<<>>( - (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, - codebook_a_sizes, codebook_stride); -} - -int codebook_stride(const torch::Tensor& codebooks) { - return codebooks.stride(0) * codebooks.element_size() / sizeof(int4); -} - -void code1x16_matvec( - const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C, - const torch::Tensor& codebook, - const int4 codebook_a_sizes // cumulative sizes of A spanning each - // codebook, at most 3 long. 
-) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); - int prob_m = C.size(0); - int prob_k = B.size(0); - - code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), - codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, - codebook_stride(codebook)); -} - -torch::Tensor code1x16_matmat(const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const int4 codebook_a_sizes, - const std::optional& bias) { - auto input_sizes = input.sizes(); - auto out_features = codes.size(0) * codebooks.size(2); - auto flat_input = input.reshape({-1, input.size(-1)}); - auto flat_output = torch::empty( - {flat_input.size(0), out_features}, - torch::TensorOptions().dtype(input.dtype()).device(input.device())); - - for (int i = 0; i < flat_input.size(0); ++i) { - auto input_vec = flat_input.index({i}); - auto output_vec = flat_output.index({i}); - code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, - codebook_a_sizes); - } - flat_output *= scales.flatten().unsqueeze(0); - - if (bias.has_value()) { - flat_output += bias->unsqueeze(0); - } - - auto output_sizes = input_sizes.vec(); - output_sizes.pop_back(); - output_sizes.push_back(-1); - auto output = flat_output.reshape(output_sizes); - return output; -} - -void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B, - torch::Tensor& C, const torch::Tensor& codebook, - const int4 codebook_a_sizes) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); - int prob_m = C.size(0); - int prob_k = B.size(0); - code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), - codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, - 2 * codebook_stride(codebook)); -} - -torch::Tensor code2x8_matmat(const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const int4 codebook_a_sizes, - const std::optional& bias) { - auto input_sizes = input.sizes(); - auto 
out_features = codes.size(0) * codebooks.size(2); - auto flat_input = input.reshape({-1, input.size(-1)}); - auto flat_output = torch::empty( - {flat_input.size(0), out_features}, - torch::TensorOptions().dtype(input.dtype()).device(input.device())); - - for (int i = 0; i < flat_input.size(0); ++i) { - auto input_vec = flat_input.index({i}); - auto output_vec = flat_output.index({i}); - code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, - codebook_a_sizes); - } - flat_output *= scales.flatten().unsqueeze(0); - if (bias.has_value()) { - flat_output += bias->unsqueeze(0); - } - - auto output_sizes = input_sizes.vec(); - output_sizes.pop_back(); - output_sizes.push_back(-1); - auto output = flat_output.reshape(output_sizes); - return output; -} - -// Accumulate the partition sizes. -int4 accumulate_sizes(const std::vector& codebook_partition_sizes) { - int4 cumulative_sizes; - auto cumulative_size = &cumulative_sizes.x; - size_t i = 0; - int last = 0; - assert(codebook_partition_sizes.size() <= 4); - for (; i < codebook_partition_sizes.size(); ++i, ++cumulative_size) { - *cumulative_size = codebook_partition_sizes[i] + last; - last = *cumulative_size; - } - // fill in the rest with unreachable. 
- for (; i < 4; ++i, ++cumulative_size) { - *cumulative_size = last * 10; - } - return cumulative_sizes; -} - -} // namespace aqlm -} // namespace vllm - -torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const std::vector& codebook_partition_sizes, - const std::optional& bias) { - int4 cumulative_sizes = - vllm::aqlm::accumulate_sizes(codebook_partition_sizes); - - int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(); - int const entries = codebooks.size(1); - - if (nbooks == 1 && entries == (1 << 16)) { - return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, - cumulative_sizes, bias); - } - if (nbooks == 2 && entries == (1 << 8)) { - return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, - cumulative_sizes, bias); - } - - TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, - " entries is not currently supported.") - return {}; -} - -torch::Tensor aqlm_dequant( - const torch::Tensor& codes, const torch::Tensor& codebooks, - const std::vector& codebook_partition_sizes) { - int4 cumulative_sizes = - vllm::aqlm::accumulate_sizes(codebook_partition_sizes); - - int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(); - int const entries = codebooks.size(1); - - const at::cuda::OptionalCUDAGuard device_guard(device_of(codes)); - int rows = codes.size(1); - int cols = codes.size(0); - - auto in_features = codes.size(1) * 8; - auto out_features = codes.size(0); - - assert(out_features == std::accumulate(codebook_partition_sizes.begin(), - codebook_partition_sizes.end(), 0)); - - auto weights = torch::empty({out_features, in_features}, - torch::TensorOptions() - .dtype(codebooks.dtype()) - .device(codebooks.device())); - - if (nbooks == 1 && entries == (1 << 16)) { - vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(), - codebooks.data_ptr(), out_features, - in_features, 
cumulative_sizes, - vllm::aqlm::codebook_stride(codebooks)); - - // if you wanted to flip to scaling the weights, (though it's 30%-ish slower - // and not consistent with gemv implementation.) weights *= - // scales.index({"...", 0, 0}); - - return weights; - } - - if (nbooks == 2 && entries == (1 << 8)) { - vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(), - codebooks.data_ptr(), out_features, - in_features, cumulative_sizes, - vllm::aqlm::codebook_stride(codebooks)); - - // if you wanted to flip to scaling the weights, (though it's 30%-ish slower - // and not consistent with gemv implementation) weights *= - // scales.index({"...", 0, 0}); - - return weights; - } - - TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, - " entries is not currently supported.") - return {}; -} diff --git a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu new file mode 100644 index 0000000000000..fdac47c425d61 --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu @@ -0,0 +1,418 @@ +// +// Based off of: +// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu +// + +#include +#include +#include +#include "cutlass_extensions/torch_utils.hpp" + +#include "core/registration.h" + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" + +#include "cutlass_extensions/common.hpp" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm::cutlass_w4a8 { + +using namespace cute; + +// ------------------------------------------------------------------------------------- +// Static 
configuration shared across all instantiations +// ------------------------------------------------------------------------------------- +using MmaType = cutlass::float_e4m3_t; // A/scale element type +using QuantType = cutlass::int4b_t; // B element type (packed int4) + +static int constexpr TileShapeK = 128 * 8 / sizeof_bits::value; +static int constexpr ScalePackSize = 8; // pack 8 scale elements together +static int constexpr PackFactor = 8; // 8 4-bit packed into int32 + +// A matrix configuration +using ElementA = MmaType; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; +constexpr int AlignmentA = + 128 / cutlass::sizeof_bits< + ElementA>::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) +using StrideA = cutlass::detail::TagToStrideA_t; + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = + cutlass::layout::ColumnMajor; // Layout type for B matrix operand +using LayoutB_Transpose = + typename cutlass::layout::LayoutTranspose::type; +constexpr int AlignmentB = + 128 / cutlass::sizeof_bits< + ElementB>::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) +using StrideB = cutlass::detail::TagToStrideB_t; + +// Define the CuTe layout for reordered quantized tensor B +// LayoutAtomQuant places values that will be read by the same thread in +// contiguous locations in global memory. 
It specifies the reordering within a +// single warp's fragment +using LayoutAtomQuant = + decltype(cutlass::compute_memory_reordering_atom()); +using LayoutB_Reordered = decltype(cute::tile_to_shape( + LayoutAtomQuant{}, Layout, StrideB>{})); + +// Group-wise scales +using ElementScale = MmaType; +using LayoutScale = cutlass::layout::RowMajor; + +// Per-tok, per-chan scales +using ElementSChannel = float; + +// C/D matrix configuration +using ElementC = + cutlass::bfloat16_t; // Element type for C and D matrix operands +using LayoutC = + cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = + 128 / cutlass::sizeof_bits< + ElementC>::value; // Memory access granularity/alignment of C + // matrix in units of elements (up to 16 bytes) + +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperative; // Kernel to launch + // based on the default + // setting in the + // Collective Builder +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + +// ---------------------------------------------------------------------------- +// Kernel template — Tile/Cluster shapes +// ---------------------------------------------------------------------------- +template +struct W4A8GemmKernel { + using TileShape = + decltype(cute::append(TileShape_MN{}, cute::Int{})); + using ClusterShape = ClusterShape_MNK; + + // Epilogue per-tok, per-chan scales + 
using ChTokScalesEpilogue = + typename vllm::c3x::ScaledEpilogue; + using EVTCompute = typename ChTokScalesEpilogue::EVTCompute; + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, + ElementAccumulator, ElementSChannel, + // Transpose layout of D here since we use explicit swap + transpose + // the void type for C tells the builder to allocate 0 smem for the C + // matrix. We can enable this if beta == 0 by changing ElementC to + // void below. + ElementC, typename cutlass::layout::LayoutTranspose::type, + AlignmentC, ElementD, + typename cutlass::layout::LayoutTranspose::type, AlignmentD, + EpilogueSchedule, // This is the only epi supporting the required + // swap + transpose. + EVTCompute>::CollectiveOp; + + // The Scale information must get paired with the operand that will be scaled. + // In this example, B is scaled so we make a tuple of B's information and the + // scale information. 
+ using CollectiveMainloopShuffled = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple>, + LayoutB_Reordered, AlignmentB, ElementA, LayoutA_Transpose, + AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopShuffled, CollectiveEpilogue>; + using GemmShuffled = + cutlass::gemm::device::GemmUniversalAdapter; + + using StrideC = typename GemmKernelShuffled::StrideC; + using StrideD = typename GemmKernelShuffled::StrideD; + using StrideS = typename CollectiveMainloopShuffled::StrideScale; + + static torch::Tensor mm(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, + torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type) { + // TODO: param validation + int m = A.size(0); + int k = A.size(1); + int n = B.size(1); + + // Allocate output + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + auto device = A.device(); + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + torch::Tensor D = + torch::empty({m, n}, torch::TensorOptions() + .dtype(equivalent_scalar_type_v) + .device(device)); + // prepare arg pointers + auto A_ptr = static_cast(A.const_data_ptr()); + auto B_ptr = static_cast(B.const_data_ptr()); + auto D_ptr = static_cast(D.data_ptr()); + // can we avoid harcode the 8 here + auto S_ptr = + static_cast const*>( + group_scales.const_data_ptr()); + + // runtime layout for B + auto shape_B = cute::make_shape(n, k, 1); + LayoutB_Reordered layout_B_reordered = + cute::tile_to_shape(LayoutAtomQuant{}, shape_B); + + // strides + int const scale_k = cutlass::ceil_div(k, group_size); + 
StrideA stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + // Reverse stride here due to swap and transpose + StrideD stride_D = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1)); + StrideS stride_S = cutlass::make_cute_packed_stride( + StrideS{}, cute::make_shape(n, scale_k, 1)); + + // Create a structure of gemm kernel arguments suitable for invoking an + // instance of Gemm auto arguments = + // args_from_options(options); + /// Populates a Gemm::Arguments structure from the given arguments + /// Swap the A and B tensors, as well as problem shapes here. + using Args = typename GemmShuffled::Arguments; + using MainloopArguments = typename GemmKernelShuffled::MainloopArguments; + using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments; + + MainloopArguments mainloop_arguments{ + B_ptr, layout_B_reordered, A_ptr, stride_A, + S_ptr, stride_S, group_size}; + + EpilogueArguments epilogue_arguments{ + ChTokScalesEpilogue::prepare_args(channel_scales, token_scales), + nullptr, + {}, // no C + D_ptr, + stride_D}; + + Args arguments{cutlass::gemm::GemmUniversalMode::kGemm, + {n, m, k, 1}, // shape + mainloop_arguments, + epilogue_arguments}; + + // Workspace + size_t workspace_size = GemmShuffled::get_workspace_size(arguments); + torch::Tensor workspace = + torch::empty(workspace_size, + torch::TensorOptions().dtype(torch::kU8).device(device)); + + // Run GEMM + GemmShuffled gemm; + CUTLASS_CHECK(gemm.can_implement(arguments)); + CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); + CUTLASS_CHECK(gemm.run(stream)); + + return D; + } +}; + +// ---------------------------------------------------------------------------- +// Kernel instantiations and dispatch logic +// ---------------------------------------------------------------------------- +using Kernel_256x128_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_256x64_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; 
+using Kernel_256x32_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_256x16_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x256_2x1x1 = + W4A8GemmKernel, Shape<_2, _1, _1>>; +using Kernel_128x256_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x128_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x64_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x32_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x16_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; + +torch::Tensor mm_dispatch(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, + torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type, + const std::string& schedule) { + if (schedule == "256x128_1x1x1") { + return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x64_1x1x1") { + return Kernel_256x64_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x32_1x1x1") { + return Kernel_256x32_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x16_1x1x1") { + return Kernel_256x16_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x256_2x1x1") { + return Kernel_128x256_2x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x256_1x1x1") { + return Kernel_128x256_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x128_1x1x1") { + return Kernel_128x128_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == 
"128x64_1x1x1") { + return Kernel_128x64_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x32_1x1x1") { + return Kernel_128x32_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x16_1x1x1") { + return Kernel_128x16_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } + TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule); + return {}; +} + +torch::Tensor mm(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type, + std::optional maybe_schedule) { + // requested a specific schedule + if (maybe_schedule) { + return mm_dispatch(A, B, group_scales, group_size, channel_scales, + token_scales, maybe_out_type, *maybe_schedule); + } + std::string schedule; + int M = A.size(0); + int K = A.size(1); + int N = B.size(1); + // heuristic + if (M <= 16) { + schedule = (K == 16384 && N == 18432) ? "256x16_1x1x1" : "128x16_1x1x1"; + } else if (M <= 32) { + schedule = (K == 16384 && N == 18432) ? 
"256x32_1x1x1" : "128x32_1x1x1"; + } else if (M <= 64) { + if (K == 16384 && N == 18432) + schedule = "256x64_1x1x1"; + else if (N <= 8192 && K <= 8192) + schedule = "128x32_1x1x1"; + else + schedule = "128x64_1x1x1"; + } else if (M <= 128) { + if (K == 16384 && N == 18432) + schedule = "256x128_1x1x1"; + else if (N <= 8192) + schedule = "128x64_1x1x1"; + else + schedule = "128x128_1x1x1"; + } else if (M <= 256) { + if (N <= 4096) + schedule = "128x64_1x1x1"; + else if (N <= 8192) + schedule = "128x128_1x1x1"; + else + schedule = "128x256_1x1x1"; + } else if (M <= 512 && N <= 4096) { + schedule = "128x128_1x1x1"; + } else if (M <= 1024) { + schedule = "128x256_1x1x1"; + } else { + schedule = "128x256_2x1x1"; + } + return mm_dispatch(A, B, group_scales, group_size, channel_scales, + token_scales, maybe_out_type, schedule); +} + +// ---------------------------------------------------------------------------- +// Pre-processing utils +// ---------------------------------------------------------------------------- +torch::Tensor pack_scale_fp8(torch::Tensor const& scales) { + TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(scales.is_contiguous()); + TORCH_CHECK(scales.is_cuda()); + + auto packed_scales = torch::empty( + {scales.numel() * ScalePackSize}, + torch::TensorOptions().dtype(scales.dtype()).device(scales.device())); + auto scales_ptr = static_cast(scales.const_data_ptr()); + auto packed_scales_ptr = + static_cast*>( + packed_scales.data_ptr()); + + cutlass::pack_scale_fp8(scales_ptr, packed_scales_ptr, scales.numel()); + + return packed_scales; +} + +torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { + TORCH_CHECK(B.dtype() == torch::kInt32); + TORCH_CHECK(B.dim() == 2); + + torch::Tensor B_packed = torch::empty_like(B); + + int k = B.size(0) * PackFactor; // logical k + int n = B.size(1); + + auto B_ptr = static_cast(B.const_data_ptr()); + auto B_packed_ptr = static_cast(B_packed.data_ptr()); + auto shape_B = 
cute::make_shape(n, k, 1); + auto layout_B = make_layout(shape_B, LayoutRight{}); // row major + LayoutB_Reordered layout_B_reordered = + cute::tile_to_shape(LayoutAtomQuant{}, shape_B); + + cutlass::unified_encode_int4b(B_ptr, B_packed_ptr, n * k); + cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered); + + return B_packed; +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("cutlass_w4a8_mm", &mm); + m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8); + m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b); +} + +} // namespace vllm::cutlass_w4a8 \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh index 6c6e89790847f..15bb2c300543c 100644 --- a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh @@ -10,7 +10,7 @@ template __global__ void get_group_gemm_starts( - int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, + int64_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets, ElementAccumulator** a_scales_offsets, ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int, ElementAB* b_base_as_int, ElementC* out_base_as_int, @@ -34,7 +34,7 @@ __global__ void get_group_gemm_starts( else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ get_group_gemm_starts \ <<<1, num_experts, 0, stream>>>( \ - static_cast(expert_offsets.data_ptr()), \ + static_cast(expert_offsets.data_ptr()), \ static_cast(a_ptrs.data_ptr()), \ static_cast(b_ptrs.data_ptr()), \ static_cast(out_ptrs.data_ptr()), \ @@ -61,6 +61,8 @@ void run_get_group_gemm_starts( TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + // expect int64_t to avoid overflow during offset calculations + TORCH_CHECK(expert_offsets.dtype() == 
torch::kInt64); int num_experts = static_cast(expert_offsets.size(0)); bool per_act_token = a_scales.numel() != 1; diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index 857cca1e82df7..49cafcc32adc6 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids, } } +namespace { +inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + torch::Tensor& atomic_buffer, + int64_t num_experts, int64_t n, + int64_t k, cudaStream_t stream, + const bool swap_ab) { + int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); + + const int32_t* topk_ptr = static_cast(topk_ids.data_ptr()); + int32_t* ps1_ptr = static_cast(problem_sizes1.data_ptr()); + int32_t* ps2_ptr = static_cast(problem_sizes2.data_ptr()); + int32_t* atomic_ptr = static_cast(atomic_buffer.data_ptr()); + + if (swap_ab) { + compute_problem_sizes<<>>( + topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr, + static_cast(topk_ids.numel()), static_cast(n), + static_cast(k)); + } else { + compute_problem_sizes<<>>( + topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr, + static_cast(topk_ids.numel()), static_cast(n), + static_cast(k)); + } +} +} // namespace + +void get_cutlass_moe_mm_problem_sizes_caller( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets) { + auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); + auto options_int32 = + torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); + torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); + + // Swap-AB should be disabled for FP4 path + bool may_swap_ab = (!blockscale_offsets.has_value()) && + 
(topk_ids.numel() <= SWAP_AB_THRESHOLD); + + launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, + atomic_buffer, num_experts, n, k, stream, + may_swap_ab); +} + void get_cutlass_moe_mm_data_caller( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller( bool may_swap_ab = (!blockscale_offsets.has_value()) && (topk_ids.numel() <= SWAP_AB_THRESHOLD); - if (may_swap_ab) { - compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, - k); - } else { - compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, - k); - } + launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, + atomic_buffer, num_experts, n, k, stream, + may_swap_ab); if (blockscale_offsets.has_value()) { // fp4 path @@ -161,6 +196,7 @@ void get_cutlass_moe_mm_data_caller( topk_ids.size(1)); } +template __global__ void compute_pplx_data(int32_t* expert_offsets, int32_t* problem_sizes1, int32_t* problem_sizes2, @@ -168,14 +204,23 @@ __global__ void compute_pplx_data(int32_t* expert_offsets, const int padded_m, const int n, const int k) { int expert_idx = threadIdx.x; - expert_offsets[expert_idx] = expert_idx * padded_m; - problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx]; - problem_sizes1[expert_idx * 3 + 1] = 2 * n; - problem_sizes1[expert_idx * 3 + 2] = k; - problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx]; - problem_sizes2[expert_idx * 3 + 1] = k; - problem_sizes2[expert_idx * 3 + 2] = n; + + if constexpr (!SWAP_AB) { + problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx]; + problem_sizes1[expert_idx * 3 + 1] = 
2 * n; + problem_sizes1[expert_idx * 3 + 2] = k; + problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx]; + problem_sizes2[expert_idx * 3 + 1] = k; + problem_sizes2[expert_idx * 3 + 2] = n; + } else { + problem_sizes1[expert_idx * 3] = 2 * n; + problem_sizes1[expert_idx * 3 + 1] = expert_num_tokens[expert_idx]; + problem_sizes1[expert_idx * 3 + 2] = k; + problem_sizes2[expert_idx * 3] = k; + problem_sizes2[expert_idx * 3 + 1] = expert_num_tokens[expert_idx]; + problem_sizes2[expert_idx * 3 + 2] = n; + } } void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, @@ -187,10 +232,19 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, const int64_t n, const int64_t k) { auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index()); - compute_pplx_data<<<1, num_local_experts, 0, stream>>>( - static_cast(expert_offsets.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(expert_num_tokens.data_ptr()), padded_m, n, - k); + if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) { + compute_pplx_data<<<1, num_local_experts, 0, stream>>>( + static_cast(expert_offsets.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(expert_num_tokens.data_ptr()), padded_m, n, + k); + } else { + compute_pplx_data<<<1, num_local_experts, 0, stream>>>( + static_cast(expert_offsets.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(expert_num_tokens.data_ptr()), padded_m, n, + k); + } } \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 106bacb4883cb..84843ee6e0949 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -76,6 +76,11 @@ void get_cutlass_moe_mm_data_caller( const int64_t num_experts, 
const int64_t n, const int64_t k, const std::optional& blockscale_offsets); +void get_cutlass_moe_mm_problem_sizes_caller( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets); + void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -293,6 +298,25 @@ void get_cutlass_moe_mm_data( version_num, ". Required capability: 90 or 100"); } +void get_cutlass_moe_mm_problem_sizes( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets) { + int32_t version_num = get_sm_version_num(); +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) + get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1, + problem_sizes2, num_experts, n, k, + blockscale_offsets); + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm " + "kernel for CUDA device capability: ", + version_num, ". 
Required capability: 90 or 100"); +} + void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 9af7833d09f32..0d14ba15937c6 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -349,9 +349,12 @@ def to_cute_constant(value: list[int]): def unique_schedules(impl_configs: list[ImplConfig]): - return list( - set(sch for impl_config in impl_configs - for sch in impl_config.schedules)) + # Use dict over set for deterministic ordering + return list({ + sch: None + for impl_config in impl_configs + for sch in impl_config.schedules + }.keys()) def unsigned_type_with_bitwidth(num_bits): @@ -568,78 +571,79 @@ def generate(): itertools.repeat(default_heuristic)) ] - # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) - # TODO (LucasWilkinson): Further tuning required - qqq_tile_heuristic_config = { - #### M = 257+ - # ((128, 256), (2, 1, 1)) Broken for QQQ types - # TODO (LucasWilkinson): Investigate further - # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), - # "M > 256": ((128, 256), (2, 1, 1)), - "M > 256": ((128, 128), (2, 1, 1)), - #### M = 129-256 - "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), - "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), - # ((128, 256), (2, 1, 1)) Broken for QQQ types - # TODO (LucasWilkinson): Investigate further - # "M > 128": ((128, 256), (2, 1, 1)), - "M > 128": ((128, 128), (2, 1, 1)), - #### M = 65-128 - "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), - "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), - "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), - "M > 64": ((128, 128), (2, 1, 1)), - #### M = 33-64 - "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), - # Broken for QQQ types - # TODO (LucasWilkinson): Investigate further - #"M 
> 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), - "M > 32": ((128, 64), (2, 1, 1)), - #### M = 17-32 - "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), - "M > 16": ((256, 32), (2, 1, 1)), - #### M = 1-16 - "N >= 26624": ((256, 16), (1, 1, 1)), - None: ((128, 16), (1, 1, 1)), - } + # TODO: Support W4A8 when ready + # # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) + # # TODO (LucasWilkinson): Further tuning required + # qqq_tile_heuristic_config = { + # #### M = 257+ + # # ((128, 256), (2, 1, 1)) Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), + # # "M > 256": ((128, 256), (2, 1, 1)), + # "M > 256": ((128, 128), (2, 1, 1)), + # #### M = 129-256 + # "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), + # "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), + # # ((128, 256), (2, 1, 1)) Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # # "M > 128": ((128, 256), (2, 1, 1)), + # "M > 128": ((128, 128), (2, 1, 1)), + # #### M = 65-128 + # "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), + # "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), + # "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), + # "M > 64": ((128, 128), (2, 1, 1)), + # #### M = 33-64 + # "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), + # # Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), + # "M > 32": ((128, 64), (2, 1, 1)), + # #### M = 17-32 + # "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), + # "M > 16": ((256, 32), (2, 1, 1)), + # #### M = 1-16 + # "N >= 26624": ((256, 16), (1, 1, 1)), + # None: ((128, 16), (1, 1, 1)), + # } - # For now we use the same heuristic for all types - # Heuristic is currently tuned for H100s - qqq_heuristic = [ - (cond, ScheduleConfig(*tile_config, - 
**sch_common_params)) # type: ignore - for cond, tile_config in qqq_tile_heuristic_config.items() - ] + # # For now we use the same heuristic for all types + # # Heuristic is currently tuned for H100s + # qqq_heuristic = [ + # (cond, ScheduleConfig(*tile_config, + # **sch_common_params)) # type: ignore + # for cond, tile_config in qqq_tile_heuristic_config.items() + # ] - QQQ_kernel_types = [ - *(TypeConfig( - a=DataType.s8, - b=VLLMDataType.u4b8, - b_group_scale=b_group_scale, - b_group_zeropoint=DataType.void, - b_channel_scale=DataType.f32, - a_token_scale=DataType.f32, - out=DataType.f16, - accumulator=DataType.s32, - ) for b_group_scale in (DataType.f16, DataType.void)), - *(TypeConfig( - a=DataType.e4m3, - b=VLLMDataType.u4b8, - b_group_scale=b_group_scale, - b_group_zeropoint=DataType.void, - b_channel_scale=DataType.f32, - a_token_scale=DataType.f32, - out=DataType.f16, - accumulator=DataType.f32, - ) for b_group_scale in (DataType.f16, DataType.void)), - ] + # QQQ_kernel_types = [ + # *(TypeConfig( + # a=DataType.s8, + # b=VLLMDataType.u4b8, + # b_group_scale=b_group_scale, + # b_group_zeropoint=DataType.void, + # b_channel_scale=DataType.f32, + # a_token_scale=DataType.f32, + # out=DataType.f16, + # accumulator=DataType.s32, + # ) for b_group_scale in (DataType.f16, DataType.void)), + # *(TypeConfig( + # a=DataType.e4m3, + # b=VLLMDataType.u4b8, + # b_group_scale=b_group_scale, + # b_group_zeropoint=DataType.void, + # b_channel_scale=DataType.f32, + # a_token_scale=DataType.f32, + # out=DataType.f16, + # accumulator=DataType.f32, + # ) for b_group_scale in (DataType.f16, DataType.void)), + # ] - impl_configs += [ - ImplConfig(x[0], x[1], x[2]) - for x in zip(QQQ_kernel_types, - itertools.repeat(get_unique_schedules(qqq_heuristic)), - itertools.repeat(qqq_heuristic)) - ] + # impl_configs += [ + # ImplConfig(x[0], x[1], x[2]) + # for x in zip(QQQ_kernel_types, + # itertools.repeat(get_unique_schedules(qqq_heuristic)), + # itertools.repeat(qqq_heuristic)) + 
# ] output_dir = os.path.join(SCRIPT_DIR, "generated") diff --git a/csrc/quantization/marlin/dense/LICENSE b/csrc/quantization/marlin/dense/LICENSE deleted file mode 100644 index 1d1e4cf9c8233..0000000000000 --- a/csrc/quantization/marlin/dense/LICENSE +++ /dev/null @@ -1,209 +0,0 @@ -Contains code from https://github.com/IST-DASLab/marlin - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. 
- - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. 
- - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. 
- - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ------------------------------------------------------------------------------------- - -This product bundles various third-party components under other open source licenses. -This section summarizes those components and their licenses. See licenses/ -for text of these licenses. diff --git a/csrc/quantization/marlin/dense/common/base.h b/csrc/quantization/marlin/dense/common/base.h deleted file mode 100644 index 68c83d5478cf8..0000000000000 --- a/csrc/quantization/marlin/dense/common/base.h +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Modified by HandH1998 - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; diff --git a/csrc/quantization/marlin/dense/common/mem.h b/csrc/quantization/marlin/dense/common/mem.h deleted file mode 100644 index 64f9c393d77ce..0000000000000 --- a/csrc/quantization/marlin/dense/common/mem.h +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Modified by HandH1998 - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. 
-__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. 
- asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu deleted file mode 100644 index ea96326ed7e61..0000000000000 --- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu +++ /dev/null @@ -1,1073 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include -#include -#include - -#include - -#include "common/base.h" -#include "core/registration.h" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - #include "common/mem.h" -#endif - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin_dense { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -using I4 = Vec; -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. 
-__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. 
- int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). 
Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. - if (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. 
- auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory - // We typically use `constexpr` to indicate that this value is a compile-time - // constant - constexpr int a_sh_stride = - 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory - constexpr int a_gl_rd_delta_o = - 16 * thread_k_blocks / - 8; // delta between subsequent A tiles in global memory - int a_gl_rd_delta_i = - a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile - constexpr int a_sh_wr_delta = - a_sh_stride * - (threads / a_gl_rd_delta_o); // between shared memory writes - constexpr int a_sh_rd_delta_o = - 2 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads - constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile - constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile - constexpr int a_sh_wr_iters = - ceildiv(a_sh_stage, - 
a_sh_wr_delta); // number of shared write iterations for a tile - - int b_gl_stride = 16 * prob_n / 32; - constexpr int b_sh_stride = 32 * thread_n_blocks / 4; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); - constexpr int b_sh_wr_delta = threads; - constexpr int b_sh_rd_delta = threads; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_sh_stage = s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = - b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x; - auto b_sh_rd = threadIdx.x; - - int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - auto s_sh_wr = threadIdx.x; - int s_sh_rd; - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. 
- if (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. 
- int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s = sh_b + (stages * b_sh_stage); - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. 
- auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); - B_ptr[i] += b_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if constexpr (group_blocks != -1) { - // This assumes group_blocks >= thread_k_blocks - // and would need to be modified to support smaller groups. - static_assert(group_blocks >= thread_k_blocks); - if (pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); - s_gl_rd += s_gl_rd_delta; - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. 
- auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticeable drop in performance. - if constexpr (group_blocks != -1) { - // This assumes group_blocks >= thread_k_blocks - // and would need to be modified to support smaller groups. - static_assert(group_blocks >= thread_k_blocks); - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; - FragB frag_b0 = dequant(b_quant); - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. 
- if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); - FragB frag_b1 = dequant(b_quant_shift); - if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; - if (red_off >= 1) { - auto red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. 
- - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). 
- constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - auto c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. 
- auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - if (group_blocks == - -1) // for per-column quantization we finally apply the scale here - res = __hmul2(res, s[0]); - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr 
+= c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); - zero_accums(); - wait_for_stage(); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - }; - start_pipes(); - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) break; - } - a_gl_rd += a_gl_rd_delta_o * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. 
- if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (group_blocks == -1 && last) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); - cp_async_fence(); - } - thread_block_reduce(); - if (group_blocks == -1 && last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - start_pipes(); - } - } - } -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage 
for barrier synchronization -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory -const int SHARED_MEM = - 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -static constexpr int tile_size = 16; -static constexpr int max_par = 16; - -static constexpr int pack_factor_4bit = - 8; // We have 8 4-bit vals inside a 32 bit - -#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - SHARED_MEM); \ - Marlin<<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 
2X -}; - -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // thread_k can be only 128 or 64 (because it must be less than groupsize - // which is 128) - if (th_config.thread_k != 128 && th_config.thread_k != 64) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - return true; -} - -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - } - - return thread_config_t{-1, -1, -1}; -} - -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) - -void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m, - int prob_n, int prob_k, void* workspace, int groupsize = -1, - 
int dev = 0, cudaStream_t stream = 0, int thread_k = -1, - int thread_n = -1, int sms = -1, int max_par = 16) { - int tot_m = prob_m; - int tot_m_blocks = ceildiv(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - - // Set thread config - thread_config_t th_config; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; - } else { - // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); - } - - if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { - throw std::runtime_error( - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + " for MKN = [" + - str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]"); - } - - // Uncomment for debug - // std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) + - // ", thread_n = " + str(th_config.thread_n) + - // ", num_threads = " + str(th_config.num_threads) + " for - // MKN = [" + str(prob_m) + - // ", " + str(prob_k) + ", " + str(prob_n) + "]\n"; - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - int group_blocks = (groupsize == -1) ? 
-1 : groupsize / 16; - int blocks = sms; - - if (prob_m == 0 || prob_n == 0 || prob_k == 0) { - return; - } - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - if (group_blocks != -1) { - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - - int* locks = (int*)workspace; - - for (int i = 0; i < tot_m_blocks; i += 4) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > 4) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) par = max_par; - prob_m = 64 * par; - i += 4 * (par - 1); - thread_m_blocks = 4; - } - - // For compilation speed, we only define the kernel configurations that have - // seemed useful (in terms of performance) in our testing, however many more - // are, in principle, possible. 
- if (false) { - } - CALL_IF(8, 8, 256) - CALL_IF(16, 4, 256) - CALL_IF(8, 4, 128) - CALL_IF(4, 8, 128) - else { - throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + - ", " + str(prob_k) + ", " + str(prob_n) + "]" + - ", groupsize = " + str(groupsize) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - } -} - -} // namespace marlin_dense - -torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t size_m, int64_t size_n, int64_t size_k) { - // Verify M - TORCH_CHECK(size_m == a.size(0), - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - - // Verify K - TORCH_CHECK(size_k == a.size(1), - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % marlin_dense::tile_size == 0, - "size_k = " + str(size_k) + " is not divisible by tile_size = " + - str(marlin_dense::tile_size)); - TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + - str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(marlin_dense::tile_size)); - - // Verify N - TORCH_CHECK(b_scales.size(1) == size_n, - "b_scales.size(1) = " + str(b_scales.size(1)) + - ", size_n = " + str(size_n)); - TORCH_CHECK( - b_q_weight.size(1) % marlin_dense::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(marlin_dense::tile_size)); - - int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) * - marlin_dense::pack_factor_4bit; - TORCH_CHECK( - size_n == actual_size_n, - "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); - - // Verify A device and strides - 
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - // Verify B device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - // Verify scales device and strides - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - // Alloc C matrix - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Detect groupsize - if (b_scales.size(0) != 1) { - TORCH_CHECK(size_k % b_scales.size(0) == 0, - "size_k = " + str(size_k) + - ", is not divisible by b_scales.size(0) = " + - str(b_scales.size(0))); - } - int groupsize = b_scales.size(0) == 1 ? 
-1 : size_k / b_scales.size(0); - - // Verify groupsize - TORCH_CHECK(groupsize == -1 || groupsize == 128, - "Unexpected groupsize = " + str(groupsize)); - - // Verify workspace size - TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + - str(marlin_dense::min_thread_n)); - int min_workspace_size = - (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); - - int dev = a.get_device(); - marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_n, sms, marlin_dense::max_par); - - return c; -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("marlin_gemm", &marlin_gemm); -} diff --git a/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu b/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu deleted file mode 100644 index c96d68d9b29aa..0000000000000 --- a/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu +++ /dev/null @@ -1,1248 +0,0 @@ -/* - * Adapted from - * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda_kernel.cu - * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda.cpp - * Modified by HandH1998 - * Copyright (C) 2024 HandH1998 - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include -#include -#include - -#include - -#include "../dense/common/base.h" -#include "core/registration.h" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - #include "../dense/common/mem.h" -#endif - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -using I4 = Vec; -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-integer-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS_GROUP = Vec; // weight per-group quantization scales -using FragS_CHANNEL = - Vec; // weight per-channel quantization scales or activaton - // per-token quantization scales - -// NOTE(HandH1998): cp.async.cg only support BYTES = 16, however, -// cp.async.ca can support BYTES = 4, 8, 16; -// as s_tok's shape is equal to prob_m, we need set s_tok to float type, -// and cp_size = 1 float, i.e., 4 BYTES -// Asynchronous global->shared copy for activation quantizaton scales s_tok -__device__ inline void cp_async1(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 4; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.ca.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// m16n8k16 tensor core mma instruction with int8 inputs and int32 -// output/accumulation. 
-__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - int* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.s8.s32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), - "r"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in int8 tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" - : "=r"(a[0]), "=r"(a[1]) - : "r"(smem)); -} - -inline __device__ half2 float2_to_half2(float2 f) { - uint32_t res; - // NOTE(HandH1998): h0,h1 should be uint16_t, not half - uint16_t h0, h1; - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h0) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h1) : "f"(f.y)); - asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(h0), "h"(h1)); - return reinterpret_cast(res); -} - -inline __device__ float int32_to_float(int h) { - float res; - asm volatile("cvt.rn.f32.s32 %0, %1;\n" : "=f"(res) : "r"(h)); - return res; -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values -// for weight per channel dequant. 
-__device__ inline FragB dequant_per_channel(int q) { - static constexpr int MASK = 0xf0f0f0f0; - FragB frag_b; - frag_b[0] = (q & MASK); - return frag_b; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values -// for weight per group dequant. -__device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) { - static constexpr uint32_t LO = 0x000f000f; - static constexpr uint32_t HI = 0x00f000f0; - static constexpr uint32_t EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - static constexpr uint32_t SUB = 0x64086408; - static constexpr uint32_t MUL = 0x2c002c00; - static constexpr uint32_t ADD = 0xd480d480; - *reinterpret_cast(&t0) = __hsub2( - *reinterpret_cast(&t0), *reinterpret_cast(&SUB)); - *reinterpret_cast(&t1) = __hfma2( - *reinterpret_cast(&t1), *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - - uint16_t s = reinterpret_cast(&frag_s)[i]; - uint32_t double_s; - // pack 2xfp16 to half2 - asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(double_s) : "h"(s), "h"(s)); - // dequant and convert 4 half to 4 uint8 (be placed at the low 8 bits of 4 - // half, respectively) - static constexpr uint32_t MAGIC_NUM = 0x64806480; - *reinterpret_cast(&t0) = __hfma2( - *reinterpret_cast(&t0), *reinterpret_cast(&double_s), - *reinterpret_cast(&MAGIC_NUM)); - *reinterpret_cast(&t1) = __hfma2( - *reinterpret_cast(&t1), *reinterpret_cast(&double_s), - *reinterpret_cast(&MAGIC_NUM)); - // take out the 4 uint8 from 4 half, then convert them to 4 int8 and pack 4 - // int8 into 1 uint32 - FragB frag_b; - uint32_t uint8s; - static constexpr uint32_t MASK_0246 = 0x6420; - static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" - : "=r"(uint8s) 
- : "r"(t0), "r"(t1), "n"(MASK_0246)); - frag_b[0] = (uint8s ^ UINT8s_TO_INT8s_MASK); - return frag_b; -} - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // int8 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // int32 global_reduce buffer of shape - // (max_par*16*4)xn, as int8 tensor core's output is - // int32 dtype - int4* __restrict__ D, // fp16 output buffer of shape mxn - const float* __restrict__ s_tok, // fp32 activation per-token quantization - // scales of shape mx1 - const int4* __restrict__ s_ch, // fp32 weight per-channel quantization - // scales of shape 1xn - const int4* __restrict__ s_group, // fp16 weight per-group quantization - // scales of shape (k/groupsize)xn, when - // group_blocks=-1, it should be nullptr - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. 
- - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. - if constexpr (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 16; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 4; - D += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - s_tok += (slice_col_par / n_tiles) * 16 * thread_m_blocks; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. 
- auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 16; - C += 16 * thread_m_blocks * prob_n / 4; - D += 16 * thread_m_blocks * prob_n / 8; - s_tok += 16 * thread_m_blocks; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - int a_gl_stride = prob_k / 16; // stride of the A matrix in global memory - // We typically use `constexpr` to indicate that this value is a compile-time - // constant - constexpr int a_sh_stride = - 16 * thread_k_blocks / 16; // stride of an A matrix tile in shared memory - constexpr int a_gl_rd_delta_o = - 16 * thread_k_blocks / - 16; // delta between subsequent A tiles in global memory - int a_gl_rd_delta_i = - a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile - constexpr int a_sh_wr_delta = - a_sh_stride * - (threads / a_gl_rd_delta_o); // between shared memory writes - constexpr int a_sh_rd_delta_o = - 1 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads - constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile - constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall 
size of a tile - constexpr int a_sh_wr_iters = - ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile - - int b_gl_stride = 16 * prob_n / 32; - constexpr int b_sh_stride = 32 * thread_n_blocks / 4; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); - constexpr int b_sh_wr_delta = threads; - constexpr int b_sh_rd_delta = threads; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - constexpr int s_tok_sh_stride = 16 * thread_m_blocks; - - constexpr int s_ch_sh_stride = 16 * thread_n_blocks / 4; - - int s_group_gl_stride = prob_n / 8; - constexpr int s_group_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_group_sh_stage = s_group_sh_stride; - int s_group_gl_rd_delta = s_group_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - // NOTE(HandH1998): int8 input a only need 16 threads to load 16x16 matrix - int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16); - a_sh_rd += 1 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = - b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x; - auto b_sh_rd = threadIdx.x; - - auto s_tok_gl_rd = threadIdx.x; - // NOTE(HandH1998): activation scale s_tok need shuffle to [0, 8, 1, 9, 2, 10, - // 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] for example, 0, 8 row scales serve for - // thread 0, 1, 2, 3. 
For more details, refer to mma operand A layout as - // s_tok's size is not fixed, we can not shuffle before inference we shuffle - // it when fetching s_tok from global memory to shared memory, that's why - // s_tok_sh_wr is like this - int s_tok_sh_wr = - (threadIdx.x / 16) * 16 + (threadIdx.x % 8) * 2 + (threadIdx.x % 16) / 8; - int s_tok_sh_rd = (threadIdx.x % 32) / 4; - bool s_tok_sh_wr_pred = threadIdx.x < prob_m; - - auto s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x; - auto s_ch_sh_wr = threadIdx.x; - int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - 2 * ((threadIdx.x % 32) % 4); - bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride; - - int s_group_gl_rd, s_group_sh_wr, s_group_sh_rd; - bool s_group_sh_wr_pred; - if constexpr (group_blocks != -1) { - s_group_gl_rd = - s_group_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_group_sh_stride * slice_col + threadIdx.x; - s_group_sh_wr = threadIdx.x; - // NOTE(HandH1998): s_group_sh_rd is related to mma output C - s_group_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - s_group_sh_wr_pred = threadIdx.x < s_group_sh_stride; - } - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. 
Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - // NOTE(HandH1998): stages need >= 4, otherwise, sh_s_tok = sh + max(stages * - // a_sh_stage + stages * b_sh_stage, 4 * stages * a_sh_stage) - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s_tok = sh_b + (stages * b_sh_stage); - int4* sh_s_ch = sh_s_tok + s_tok_sh_stride; - int4* sh_s_group = sh_s_ch + s_ch_sh_stride; - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS_GROUP frag_s_group[2][4]; - FragS_CHANNEL frag_s_tok[thread_m_blocks]; - FragS_CHANNEL frag_s_ch[2][4]; - - // Zero accumulators. 
- auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); - B_ptr[i] += b_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if constexpr (group_blocks != -1) { - if (pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_group_stage = sh_s_group + s_group_sh_stage * pipe; - if (s_group_sh_wr_pred) - cp_async4(&sh_s_group_stage[s_group_sh_wr], - &s_group[s_group_gl_rd]); - s_group_gl_rd += s_group_gl_rd_delta; - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. 
- auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticeable drop in performance. - if constexpr (group_blocks != -1) { - int4* sh_s_group_stage = - sh_s_group + - s_group_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s_group[k % 2])[0] = - sh_s_group_stage[s_group_sh_rd]; - } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - // int b_quant_shift = b_quant << 4; - FragB frag_b0, frag_b1; - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. 
- if constexpr (group_blocks != -1) { - int b_quant_shift = b_quant >> 8; - frag_b0 = dequant_per_group(b_quant, frag_s_group[k % 2][j], 0); - frag_b1 = dequant_per_group(b_quant_shift, frag_s_group[k % 2][j], 1); - } else { - int b_quant_shift = b_quant << 4; - frag_b0 = dequant_per_channel(b_quant); - frag_b1 = dequant_per_channel(b_quant_shift); - } - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; - if (red_off >= 1) { - auto red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. 
- - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - int* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - int* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - int* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - // global_reduce works on INT32 elements, which are the results of INT8 GEMM. - // This is why we need another INT32 maxtrix `C` to reduce instead of the - // original half matrix `D`. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). 
- constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 4; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 8 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2; - c_gl_wr += (4 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads * 2; - auto c_sh_wr = 2 * threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i + 1], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2) + 1], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 d_red1 = sh[c_sh_wr + i * c_sh_wr_delta]; - int4 d_red2 = sh[c_sh_wr + i * c_sh_wr_delta + 1]; - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - reinterpret_cast(&d_red1)[j]; - } - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)] += - reinterpret_cast(&d_red2)[j]; - } - } - if (!last) { - int4 d1, d2; - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast(&d1)[j] = reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]; - } 
- #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast(&d2)[j] = reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)]; - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - d1; - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2) + - 1] = d2; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int d_gl_stride = prob_n / 8; - constexpr int d_sh_stride = 2 * thread_n_blocks + 1; - int d_gl_wr_delta = d_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int d_sh_rd_delta = - d_sh_stride * (threads / (2 * thread_n_blocks)); - - int d_gl_wr = d_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - d_gl_wr += (2 * thread_n_blocks) * slice_col; - int d_sh_wr = - (4 * d_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - d_sh_wr += 32 * (threadIdx.x / 32); - int d_sh_rd = d_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int d_gl_wr_end = d_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, int c0, int c1, float a_s, FragS_CHANNEL& w_s) { - float2 deq_res; - deq_res.x = int32_to_float(c0) * w_s[0] * a_s; - deq_res.y = int32_to_float(c1) * w_s[1] * a_s; - ((half2*)sh)[idx] = float2_to_half2(deq_res); - }; - - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = d_sh_wr + 8 * j; - write(wr + (4 * d_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s_tok[i][0], - frag_s_ch[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * d_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], 
frag_s_tok[i][1], - frag_s_ch[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * d_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s_tok[i][0], - frag_s_ch[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * d_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s_tok[i][1], - frag_s_ch[j / 2][2 * (j % 2) + 1]); - } - d_sh_wr += 16 * (4 * d_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (d_gl_wr < d_gl_wr_end) { - D[d_gl_wr] = sh[d_sh_rd]; - d_gl_wr += d_gl_wr_delta; - d_sh_rd += d_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); - zero_accums(); - wait_for_stage(); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - }; - start_pipes(); - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) break; - } - a_gl_rd += a_gl_rd_delta_o * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. 
- if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (last) { - if (s_tok_sh_wr_pred) { - cp_async1(&sh_s_tok[s_tok_sh_wr], &s_tok[s_tok_gl_rd]); - } - if (s_ch_sh_wr_pred) { - cp_async4(&sh_s_ch[s_ch_sh_wr], &s_ch[s_ch_gl_rd]); - } - cp_async_fence(); - } - thread_block_reduce(); - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - frag_s_tok[i][0] = - *reinterpret_cast(&sh_s_tok[16 * i + 2 * s_tok_sh_rd]); - frag_s_tok[i][1] = *reinterpret_cast( - &sh_s_tok[16 * i + 2 * s_tok_sh_rd + 1]); - } - reinterpret_cast(&frag_s_ch)[0] = sh_s_ch[s_ch_sh_rd + 0]; - reinterpret_cast(&frag_s_ch)[1] = sh_s_ch[s_ch_sh_rd + 1]; - reinterpret_cast(&frag_s_ch)[2] = sh_s_ch[s_ch_sh_rd + 8]; - reinterpret_cast(&frag_s_ch)[3] = sh_s_ch[s_ch_sh_rd + 9]; - } - } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - s_group_gl_rd = s_group_sh_stride * slice_col + threadIdx.x; - s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x; - start_pipes(); - } - } - } -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 
blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // int8 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // int32 global_reduce buffer of shape - // (max_par*16*4)xn, as int8 tensor core's output is - // int32 dtype - int4* __restrict__ D, // fp16 output buffer of shape mxn - const float* __restrict__ s_tok, // fp32 activation per-token quantization - // scales of shape mx1 - const int4* __restrict__ s_ch, // fp32 weight per-channel quantization - // scales of shape 1xn - const int4* __restrict__ s_group, // fp16 weight per-group quantization - // scales of shape (k/groupsize)xn, when - // group_blocks=-1, it should be nullptr - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. 
-const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -static constexpr int tile_size = 16; -static constexpr int max_par = 16; - -static constexpr int pack_factor_4bit = - 8; // We have 8 4-bit vals inside a 32 bit - -#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, D_ptr, s_tok_ptr, s_ch_ptr, s_group_ptr, \ - prob_m, prob_n, prob_k, locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X -}; - -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // thread_k can be only 128 
or 64 (because it must be less than groupsize - // which is 128) - if (th_config.thread_k != 128 && th_config.thread_k != 64) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - return true; -} - -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - } - - return thread_config_t{-1, -1, -1}; -} - -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) - -void marlin_qqq_cuda(const void* A, const void* B, void* C, void* D, - void* s_tok, void* s_ch, void* s_group, int prob_m, - int prob_n, int prob_k, void* workspace, - int groupsize = -1, int dev = 0, cudaStream_t stream = 0, - int thread_k = -1, int thread_n = -1, int sms = -1, - int max_par = 16) { - int tot_m = prob_m; - int tot_m_blocks = ceildiv(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - - int max_shared_mem = 0; - 
cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - // Set thread config - thread_config_t th_config; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; - } else { - // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); - } - - if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { - throw std::runtime_error( - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + " for MKN = [" + - str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]"); - } - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; - int blocks = sms; - - if (prob_m == 0 || prob_n == 0 || prob_k == 0) { - return; - } - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - if (group_blocks != -1) { - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - int4* D_ptr = (int4*)D; - const float* s_tok_ptr = (const float*)s_tok; - const int4* s_ch_ptr = (const int4*)s_ch; - const int4* s_group_ptr = (const int4*)s_group; - - int* locks = (int*)workspace; - - for (int i = 0; i < tot_m_blocks; i += 4) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > 4) { - // Note that parallel > 1 currently only works for inputs without any - // 
padding - par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) par = max_par; - prob_m = 64 * par; - i += 4 * (par - 1); - thread_m_blocks = 4; - } - - // For compilation speed, we only define the kernel configurations that have - // seemed useful (in terms of performance) in our testing, however many more - // are, in principle, possible. - if (false) { - } - CALL_IF(8, 8, 256) - CALL_IF(16, 4, 256) - CALL_IF(8, 4, 128) - CALL_IF(4, 8, 128) - else { - throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + - ", " + str(prob_k) + ", " + str(prob_n) + "]" + - ", groupsize = " + str(groupsize) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 16) * par; - D_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - s_tok_ptr += 16 * thread_m_blocks * par; - } -} -} // anonymous namespace - -torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, - torch::Tensor const& b_q_weight, - torch::Tensor const& s_tok, - torch::Tensor const& s_ch, - torch::Tensor const& s_group, - torch::Tensor& workspace, int64_t size_m, - int64_t size_n, int64_t size_k) { - // Verify M - TORCH_CHECK(size_m == a.size(0), - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - TORCH_CHECK(size_m == s_tok.numel(), - "Shape mismatch: s_tok.numel() = " + str(s_tok.numel()) + - ", size_m = " + str(size_m)); - - // Verify K - TORCH_CHECK(size_k == a.size(1), - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % tile_size == 0, - "size_k = " + str(size_k) + - " is not divisible by tile_size = " + str(tile_size)); - TORCH_CHECK( - (size_k / tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + str(b_q_weight.size(0)) + - ", size_k = " + str(size_k) + ", tile_size = " + str(tile_size)); - - int groupsize = (s_group.numel() == 0) ? 
-1 : size_k / s_group.size(0); - // Verify groupsize - TORCH_CHECK(groupsize == -1 || groupsize == 128, - "Unexpected groupsize = " + str(groupsize)); - - // Verify N - TORCH_CHECK(s_ch.numel() == size_n, - "Shape mismatch: s_ch.numel() = " + str(s_ch.numel()) + - ", size_n = " + str(size_n)); - TORCH_CHECK(b_q_weight.size(1) % tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(tile_size)); - if (groupsize != -1) { - TORCH_CHECK(s_group.size(1) == size_n, - "Shape mismatch: s_group.size(1) = " + str(s_group.size(1)) + - ", size_n = " + str(size_n)); - TORCH_CHECK( - size_k % s_group.size(0) == 0, - "size_k = " + str(size_k) + - ", is not divisible by s_group.size(0) = " + str(s_group.size(0))); - } - - int actual_size_n = (b_q_weight.size(1) / tile_size) * pack_factor_4bit; - TORCH_CHECK(size_n == actual_size_n, - "Shape mismatch: size_n = " + str(size_n) + - ", actual_size_n = " + str(actual_size_n)); - - // Verify A device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - // Verify B device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - // Verify s_tok device, strides and dtype - TORCH_CHECK(s_tok.device().is_cuda(), "s_tok is not on GPU"); - TORCH_CHECK(s_tok.is_contiguous(), "s_tok is not contiguous"); - TORCH_CHECK(s_tok.dtype() == torch::kFloat32, "s_tok's dtype is not float32"); - - // Verify s_ch device, strides and dtype - TORCH_CHECK(s_ch.device().is_cuda(), "s_ch is not on GPU"); - TORCH_CHECK(s_ch.is_contiguous(), "s_ch is not contiguous"); - TORCH_CHECK(s_ch.dtype() == torch::kFloat32, "s_ch's dtype is not float32"); - - // Verify s_group device, strides and dtype - TORCH_CHECK(s_group.device().is_cuda(), "s_group is not on GPU"); - TORCH_CHECK(s_group.is_contiguous(), "s_group is not 
contiguous"); - TORCH_CHECK(s_group.dtype() == torch::kFloat16, - "s_group's dtype is not float16"); - - // Verify workspace size - TORCH_CHECK(size_n % min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + str(min_thread_n)); - int min_workspace_size = (size_n / min_thread_n) * max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); - - // Alloc C matrix - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options_c = torch::TensorOptions().dtype(torch::kInt).device(a.device()); - torch::Tensor c = torch::empty({max_par * 64, size_n}, options_c); - - // Alloc D matrix - auto options_d = - torch::TensorOptions().dtype(torch::kFloat16).device(a.device()); - torch::Tensor d = torch::empty({size_m, size_n}, options_d); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - int dev = a.get_device(); - marlin_qqq_cuda( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), d.data_ptr(), - s_tok.data_ptr(), s_ch.data_ptr(), s_group.data_ptr(), size_m, size_n, - size_k, workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par); - - return d; -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("marlin_qqq_gemm", &marlin_qqq_gemm); -} diff --git a/csrc/quantization/vectorization_utils.cuh b/csrc/quantization/vectorization_utils.cuh index 8aa0147df6ba8..98b491b7e23fc 100644 --- a/csrc/quantization/vectorization_utils.cuh +++ b/csrc/quantization/vectorization_utils.cuh @@ -41,8 +41,10 @@ __device__ inline void vectorize_with_alignment( for (int i = tid; i < 
num_vec; i += stride) { vout_t tmp; - vec_op(tmp, v_in[i]); - v_out[i] = tmp; + // Make a local copy of the entire pack + vin_t src = v_in[i]; // <- encourages a single vector ld + vec_op(tmp, src); + v_out[i] = tmp; // <- encourages a single vector st } return; } @@ -71,8 +73,10 @@ __device__ inline void vectorize_with_alignment( // 2. vectorize the main part for (int i = tid; i < num_vec; i += stride) { vout_t tmp; - vec_op(tmp, v_in[i]); - v_out[i] = tmp; + // Make a local copy of the entire pack + vin_t src = v_in[i]; // <- encourages a single vector ld + vec_op(tmp, src); + v_out[i] = tmp; // <- encourages a single vector st } // 3. handle the tail @@ -125,7 +129,8 @@ __device__ inline void vectorize_read_with_alignment(const InT* in, int len, auto* v_in = reinterpret_cast(in); for (int i = tid; i < num_vec; i += stride) { - vec_op(v_in[i]); + vin_t tmp = v_in[i]; + vec_op(tmp); } return; } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index a547baec50d6a..7ae054dc19fbd 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -130,6 +130,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()"); ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul); + ops.def( + "swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float " + "limit=7.0) " + "-> ()"); + ops.impl("swigluoai_and_mul", torch::kCUDA, &swigluoai_and_mul); + // GELU implementation used in GPT-2. ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_new", torch::kCUDA, &gelu_new); @@ -207,21 +213,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Quantization ops #ifndef USE_ROCM - // Quantized GEMM for AQLM. - ops.def( - "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, " - "Tensor scales, int[] codebook_partition_sizes, Tensor? 
bias) " - "-> Tensor", - {stride_tag}); - ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm); - - // Decompression method for AQLM. - ops.def( - "aqlm_dequant(Tensor codes, Tensor codebooks, " - "int[] codebook_partition_sizes) -> Tensor", - {stride_tag}); - ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant); - // Quantized GEMM for AWQ. ops.def( "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, " @@ -250,14 +241,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // custom types: // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA - // Marlin (Dense) Optimized Quantized GEMM for GPTQ. - ops.def( - "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " - "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> " - "Tensor", - {stride_tag}); - // conditionally compiled so impl in source file - // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. ops.def( "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, " @@ -326,6 +309,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " "SymInt size_n, int num_bits) -> Tensor"); // conditionally compiled so impl registrations are in source file + + // CUTLASS w4a8 GEMM + ops.def( + "cutlass_w4a8_mm(" + " Tensor A," + " Tensor B," + " Tensor group_scales," + " int group_size," + " Tensor channel_scales," + " Tensor token_scales," + " ScalarType? out_type," + " str? maybe_schedule" + ") -> Tensor", + {stride_tag}); + // pack scales + ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor"); + // encode and reorder weight matrix + ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"); + // conditionally compiled so impl registration is in source file + #endif // Dequantization for GGML. @@ -362,15 +365,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size); #ifndef USE_ROCM - // marlin_qqq_gemm for QQQ. 
- ops.def( - "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, " - "Tensor s_tok, Tensor s_ch, Tensor s_group, " - "Tensor! workspace, SymInt size_m, SymInt size_n, " - "SymInt size_k) -> Tensor", - {stride_tag}); - // conditionally compiled so impl registration is in source file - // CUTLASS nvfp4 block scaled GEMM ops.def( "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b," @@ -449,6 +443,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data); + // A function that computes problem sizes for each expert's multiplication + // used by the two mms called from fused MoE operation. It takes topk_ids as + // an input, and computes problem_sizes1 and problem_sizes2 only. + ops.def( + "get_cutlass_moe_mm_problem_sizes(Tensor topk_ids, " + " Tensor! problem_sizes1, " + " Tensor! problem_sizes2, " + " int num_experts, int n, int k, " + " Tensor? blockscale_offsets) -> ()", + {stride_tag}); + ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA, + &get_cutlass_moe_mm_problem_sizes); + // A function that computes data required to run fused MoE with w8a8 grouped // GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs // as an input, and computes expert_offsets (token start indices of each @@ -685,11 +692,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "str kv_cache_dtype) -> ()"); cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); - // Gather cache blocks from src_cache to dst. + // Gather cache blocks from src_cache to dst, dequantizing from + // src_cache's dtype to dst's dtype if necessary. cache_ops.def( - "gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, " - "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); - cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache); + "gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! 
dst, " + " Tensor block_table, Tensor cu_seq_lens, " + " int batch_size, " + " str kv_cache_dtype, " + " Tensor scale, Tensor? seq_starts) -> ()"); + cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA, + &gather_and_maybe_dequant_cache); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { diff --git a/docker/Dockerfile b/docker/Dockerfile index 66a6e6fd6f67d..2e272cbca8417 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -139,21 +139,6 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ WORKDIR /workspace # install build and runtime dependencies - -# arm64 (GH200) build follows the practice of "use existing pytorch" build, -# we need to install torch and torchvision from the nightly builds first, -# pytorch will not appear as a vLLM dependency in all of the following steps -# after this step -RUN --mount=type=cache,target=/root/.cache/uv \ - if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ - uv pip install --system \ - --index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \ - "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \ - uv pip install --system \ - --index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. 
-f1,2 | tr -d '.') \ - --pre pytorch_triton==3.3.0+gitab727c40; \ - fi - COPY requirements/common.txt requirements/common.txt COPY requirements/cuda.txt requirements/cuda.txt RUN --mount=type=cache,target=/root/.cache/uv \ @@ -234,6 +219,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \ && sccache --show-stats; \ fi +ARG vllm_target_device="cuda" +ENV VLLM_TARGET_DEVICE=${vllm_target_device} ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/ccache \ --mount=type=cache,target=/root/.cache/uv \ @@ -385,31 +372,45 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist # Install FlashInfer from source ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" -# Keep this in sync with https://github.com/vllm-project/vllm/blob/main/requirements/cuda.txt -# We use `--force-reinstall --no-deps` to avoid issues with the existing FlashInfer wheel. -ARG FLASHINFER_GIT_REF="v0.2.11" +# Keep this in sync with "flashinfer" extra in setup.py +ARG FLASHINFER_GIT_REF="v0.2.14.post1" +# Flag to control whether to compile FlashInfer AOT kernels +# Set to "true" to enable AOT compilation: +# docker build --build-arg FLASHINFER_AOT_COMPILE=true ... +ARG FLASHINFER_AOT_COMPILE=false RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' . /etc/environment git clone --depth 1 --recursive --shallow-submodules \ --branch ${FLASHINFER_GIT_REF} \ ${FLASHINFER_GIT_REPO} flashinfer - # Exclude CUDA arches for older versions (11.x and 12.0-12.7) - # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg. 
- if [[ "${CUDA_VERSION}" == 11.* ]]; then - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" - elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" - else - # CUDA 12.8+ supports 10.0a and 12.0 - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" - fi - echo "🏗️ Building FlashInfer for arches: ${FI_TORCH_CUDA_ARCH_LIST}" - # Needed to build AOT kernels pushd flashinfer - TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ - python3 -m flashinfer.aot - TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ - uv pip install --system --no-build-isolation --force-reinstall --no-deps . + if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then + # Exclude CUDA arches for older versions (11.x and 12.0-12.7) + # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg. + if [[ "${CUDA_VERSION}" == 11.* ]]; then + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" + elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" + else + # CUDA 12.8+ supports 10.0a and 12.0 + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" + fi + echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}" + # Build AOT kernels + TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ + python3 -m flashinfer.aot + # Install with no-build-isolation since we already built AOT kernels + TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ + uv pip install --system --no-build-isolation . \ + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') + # Download pre-compiled cubins + TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ + python3 -m flashinfer --download-cubin || echo "WARNING: Failed to download flashinfer cubins." + else + echo "🏗️ Installing FlashInfer without AOT compilation in JIT mode" + uv pip install --system . \ + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. 
-f1,2 | tr -d '.') + fi popd rm -rf flashinfer BASH @@ -431,31 +432,19 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') # Install DeepGEMM from source -ARG DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git" ARG DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c" -RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' - . /etc/environment - CUDA_MAJOR="${CUDA_VERSION%%.*}" - CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}" - CUDA_MINOR="${CUDA_MINOR%%.*}" - if [ "$CUDA_MAJOR" -ge 12 ] && [ "$CUDA_MINOR" -ge 8 ]; then - git clone --recursive --shallow-submodules \ - ${DEEPGEMM_GIT_REPO} deepgemm - echo "🏗️ Building DeepGEMM" - pushd deepgemm - git checkout ${DEEPGEMM_GIT_REF} - # Build DeepGEMM - # (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh) - rm -rf build dist - rm -rf *.egg-info - python3 setup.py bdist_wheel - uv pip install --system dist/*.whl - popd - rm -rf deepgemm - else - echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})" - fi -BASH +COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh +RUN --mount=type=cache,target=/root/.cache/uv \ + VLLM_DOCKER_BUILD_CONTEXT=1 /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" --ref "${DEEPGEMM_GIT_REF}" \ + && rm /tmp/install_deepgemm.sh + +# Install EP kernels(pplx-kernels and DeepEP), NixL +COPY tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh +COPY tools/install_nixl.sh install_nixl.sh +ENV CUDA_HOME=/usr/local/cuda +RUN export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0a+PTX}" \ + && bash install_python_libraries.sh \ + && bash install_nixl.sh --force #################### vLLM installation IMAGE #################### diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 4f40f32a39f26..f164857325043 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -71,7 +71,7 @@ 
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace RUN cd /vllm-workspace \ && rm -rf vllm \ && python3 -m pip install -e tests/vllm_test_utils \ - && python3 -m pip install lm-eval[api]==0.4.4 \ + && python3 -m pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] \ && python3 -m pip install pytest-shard # ----------------------- diff --git a/docker/Dockerfile.s390x b/docker/Dockerfile.s390x index 4e89bb3057c5e..9270b48c54d4b 100644 --- a/docker/Dockerfile.s390x +++ b/docker/Dockerfile.s390x @@ -16,7 +16,7 @@ ENV LANG=C.UTF-8 \ RUN microdnf install -y \ which procps findutils tar vim git gcc gcc-gfortran g++ make patch zlib-devel \ libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \ - openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy && \ + openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy libsndfile && \ microdnf clean all # Python Installation @@ -136,6 +136,71 @@ RUN --mount=type=cache,target=/root/.cache/uv \ mkdir -p /tmp/hf-xet/dist && \ cp dist/*.whl /tmp/hf-xet/dist/ +# Build numba +FROM python-install AS numba-builder + +ARG MAX_JOBS +ARG NUMBA_VERSION=0.61.2 + +WORKDIR /tmp + +# Clone all required dependencies +RUN --mount=type=cache,target=/root/.cache/uv \ + microdnf install ninja-build gcc gcc-c++ -y && \ + git clone --recursive https://github.com/llvm/llvm-project.git -b llvmorg-15.0.7 && \ + git clone --recursive https://github.com/numba/llvmlite.git -b v0.44.0 && \ + git clone --recursive https://github.com/numba/numba.git -b ${NUMBA_VERSION} && \ + cd llvm-project && mkdir build && cd build && \ + uv pip install 'cmake<4' setuptools numpy && \ + export PREFIX=/usr/local && CMAKE_ARGS="${CMAKE_ARGS} -DLLVM_ENABLE_PROJECTS=lld;libunwind;compiler-rt" \ + CFLAGS="$(echo $CFLAGS | sed 's/-fno-plt //g')" \ + CXXFLAGS="$(echo $CXXFLAGS | sed 's/-fno-plt //g')" \ + 
CMAKE_ARGS="${CMAKE_ARGS} -DFFI_INCLUDE_DIR=$PREFIX/include" \ + CMAKE_ARGS="${CMAKE_ARGS} -DFFI_LIBRARY_DIR=$PREFIX/lib" \ + cmake -DCMAKE_INSTALL_PREFIX="${PREFIX}" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_LIBRARY_PATH="${PREFIX}" \ + -DLLVM_ENABLE_LIBEDIT=OFF \ + -DLLVM_ENABLE_LIBXML2=OFF \ + -DLLVM_ENABLE_RTTI=ON \ + -DLLVM_ENABLE_TERMINFO=OFF \ + -DLLVM_INCLUDE_BENCHMARKS=OFF \ + -DLLVM_INCLUDE_DOCS=OFF \ + -DLLVM_INCLUDE_EXAMPLES=OFF \ + -DLLVM_INCLUDE_GO_TESTS=OFF \ + -DLLVM_INCLUDE_TESTS=OFF \ + -DLLVM_INCLUDE_UTILS=ON \ + -DLLVM_INSTALL_UTILS=ON \ + -DLLVM_UTILS_INSTALL_DIR=libexec/llvm \ + -DLLVM_BUILD_LLVM_DYLIB=OFF \ + -DLLVM_LINK_LLVM_DYLIB=OFF \ + -DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD=WebAssembly \ + -DLLVM_ENABLE_FFI=ON \ + -DLLVM_ENABLE_Z3_SOLVER=OFF \ + -DLLVM_OPTIMIZED_TABLEGEN=ON \ + -DCMAKE_POLICY_DEFAULT_CMP0111=NEW \ + -DCOMPILER_RT_BUILD_BUILTINS=ON \ + -DCOMPILER_RT_BUILTINS_HIDE_SYMBOLS=OFF \ + -DCOMPILER_RT_BUILD_LIBFUZZER=OFF \ + -DCOMPILER_RT_BUILD_CRT=OFF \ + -DCOMPILER_RT_BUILD_MEMPROF=OFF \ + -DCOMPILER_RT_BUILD_PROFILE=OFF \ + -DCOMPILER_RT_BUILD_SANITIZERS=OFF \ + -DCOMPILER_RT_BUILD_XRAY=OFF \ + -DCOMPILER_RT_BUILD_GWP_ASAN=OFF \ + -DCOMPILER_RT_BUILD_ORC=OFF \ + -DCOMPILER_RT_INCLUDE_TESTS=OFF \ + ${CMAKE_ARGS} -GNinja ../llvm \ + + && ninja install . && \ + # build llvmlite + cd ../../llvmlite && python setup.py bdist_wheel && \ + cd ../numba && \ + if ! 
grep '#include "dynamic_annotations.h"' numba/_dispatcher.cpp; then \ + sed -i '/#include "internal\/pycore_atomic.h"/i\#include "dynamic_annotations.h"' numba/_dispatcher.cpp; \ + fi && python setup.py bdist_wheel + + # Final build stage FROM python-install AS vllm-cpu ARG PYTHON_VERSION @@ -163,23 +228,30 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \ --mount=type=bind,from=hf-xet-builder,source=/tmp/hf-xet/dist,target=/tmp/hf-xet-wheels/ \ --mount=type=bind,from=torch,source=/tmp/pytorch/dist,target=/tmp/torch-wheels/ \ + --mount=type=bind,from=numba-builder,source=/tmp/llvmlite/dist,target=/tmp/llvmlite-wheels/ \ + --mount=type=bind,from=numba-builder,source=/tmp/numba/dist,target=/tmp/numba-wheels/ \ sed -i '/^torch/d' requirements/build.txt && \ - ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \ - VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \ - HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl | head -n 1) && \ - TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl | head -n 1) && \ + ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl) && \ + VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl) && \ + HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl) && \ + TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl) && \ + LLVM_WHL_FILE=$(ls /tmp/llvmlite-wheels/*.whl) && \ + NUMBA_WHL_FILE=$(ls /tmp/numba-wheels/*.whl) && \ uv pip install -v \ $ARROW_WHL_FILE \ $VISION_WHL_FILE \ $HF_XET_WHL_FILE \ $TORCH_WHL_FILE \ + $LLVM_WHL_FILE \ + $NUMBA_WHL_FILE \ --index-strategy unsafe-best-match \ -r requirements/build.txt \ - -r requirements/cpu.txt + -r requirements/cpu.txt + # Build and install vllm RUN --mount=type=cache,target=/root/.cache/uv \ - VLLM_TARGET_DEVICE=cpu python setup.py bdist_wheel && \ + VLLM_TARGET_DEVICE=cpu VLLM_CPU_MOE_PREPACK=0 python setup.py bdist_wheel && \ uv pip install "$(echo dist/*.whl)[tensorizer]" # setup non-root user for vllm @@ 
-196,4 +268,3 @@ WORKDIR /home/vllm # Set the default entrypoint ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"] - diff --git a/docker/Dockerfile.tpu b/docker/Dockerfile.tpu index 2190151369761..ca2d7833c1efa 100644 --- a/docker/Dockerfile.tpu +++ b/docker/Dockerfile.tpu @@ -7,7 +7,8 @@ WORKDIR /workspace/vllm # Install some basic utilities RUN apt-get update && apt-get install -y \ git \ - ffmpeg libsm6 libxext6 libgl1 + ffmpeg libsm6 libxext6 libgl1 && \ + rm -rf /var/lib/apt/lists/* # Build vLLM. COPY . . @@ -16,6 +17,9 @@ RUN --mount=type=bind,source=.git,target=.git \ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi # Remove existing versions of dependencies +# TODO: These packages will remain as dead weight in the Docker image layers. +# We should find a way to build the image without uninstalling these. +# Consider using a different base image. RUN pip uninstall -y torch torch_xla torchvision ENV VLLM_TARGET_DEVICE="tpu" @@ -23,9 +27,10 @@ RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=.git,target=.git \ python3 -m pip install \ -r requirements/tpu.txt -RUN python3 -m pip install -e . + +RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install -e . # install development dependencies (for testing) -RUN python3 -m pip install -e tests/vllm_test_utils +RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install -e tests/vllm_test_utils CMD ["/bin/bash"] diff --git a/docs/api/README.md b/docs/api/README.md index 327472df1d52c..57142e8f5625d 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -77,6 +77,7 @@ Internal data structures. 
- [vllm.multimodal.inputs.MultiModalFieldElem][] - [vllm.multimodal.inputs.MultiModalFieldConfig][] - [vllm.multimodal.inputs.MultiModalKwargsItem][] +- [vllm.multimodal.inputs.MultiModalKwargsItems][] - [vllm.multimodal.inputs.MultiModalKwargs][] - [vllm.multimodal.inputs.MultiModalInputs][] diff --git a/docs/assets/design/hybrid_kv_cache_manager/basic_grouping_example.png b/docs/assets/design/hybrid_kv_cache_manager/basic_grouping_example.png new file mode 100644 index 0000000000000..185f61e6a3ede Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/basic_grouping_example.png differ diff --git a/docs/assets/design/hybrid_kv_cache_manager/full_attn.png b/docs/assets/design/hybrid_kv_cache_manager/full_attn.png new file mode 100644 index 0000000000000..30eade5c7051c Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/full_attn.png differ diff --git a/docs/assets/design/hybrid_kv_cache_manager/memory_layout.png b/docs/assets/design/hybrid_kv_cache_manager/memory_layout.png new file mode 100644 index 0000000000000..bcffc27a71649 Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/memory_layout.png differ diff --git a/docs/assets/design/hybrid_kv_cache_manager/overview.png b/docs/assets/design/hybrid_kv_cache_manager/overview.png new file mode 100644 index 0000000000000..ac80581f491da Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/overview.png differ diff --git a/docs/assets/design/hybrid_kv_cache_manager/sw_attn.png b/docs/assets/design/hybrid_kv_cache_manager/sw_attn.png new file mode 100644 index 0000000000000..10aa6146dc7ab Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/sw_attn.png differ diff --git a/docs/community/meetups.md b/docs/community/meetups.md index 36232e6ad96cc..221a7bd96213f 100644 --- a/docs/community/meetups.md +++ b/docs/community/meetups.md @@ -2,6 +2,8 @@ We host regular meetups in San Francisco Bay Area every 2 months. 
We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: +- [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg), August 23rd 2025. [[Slides]](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH) +- [vLLM Korea Meetup](https://luma.com/cgcgprmh), August 19th 2025. [[Slides]](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view). - [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA), August 2nd 2025. [[Slides]](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) [[Recording]](https://www.chaspark.com/#/live/1166916873711665152). - [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing) - [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). diff --git a/docs/configuration/conserving_memory.md b/docs/configuration/conserving_memory.md index 058eba5fe0b1e..efda9c8e019eb 100644 --- a/docs/configuration/conserving_memory.md +++ b/docs/configuration/conserving_memory.md @@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", If you run out of CPU RAM, try the following options: -- (Multi-modal models only) you can set the size of multi-modal processor cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB per API process + 4 GiB per engine core process) +- (Multi-modal models only) you can set the size of multi-modal cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB). 
- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). ## Multi-modal input limits diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 2eeb8ad25de5f..b11ccb5c00273 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -48,7 +48,7 @@ You can tune the performance by adjusting `max_num_batched_tokens`: - Smaller values (e.g., 2048) achieve better inter-token latency (ITL) because there are fewer prefills slowing down decodes. - Higher values achieve better time to first token (TTFT) as you can process more prefill tokens in a batch. -- For optimal throughput, we recommend setting `max_num_batched_tokens > 8096` especially for smaller models on large GPUs. +- For optimal throughput, we recommend setting `max_num_batched_tokens > 8192` especially for smaller models on large GPUs. - If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the V0 default scheduling policy (except that it still prioritizes decodes). ```python @@ -129,6 +129,56 @@ Data parallelism replicates the entire model across multiple GPU sets and proces Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`. Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size. +### Batch-level DP for Multi-Modal Encoders + +By default, TP is used to shard the weights of multi-modal encoders just like for language decoders, +in order to reduce the memory and compute load on each GPU. + +However, since the size of multi-modal encoders is very small compared to language decoders, +there is relatively little gain from TP. On the other hand, TP incurs significant communication +overhead because of all-reduce being performed after every layer. 
+ +Given this, it may be advantageous to instead shard the batched input data using TP, essentially +performing batch-level DP. This has been shown to improve the throughput by around 10% for +`tensor_parallel_size=8`. For vision encoders that use hardware-unoptimized Conv3D operations, +batch-level DP can provide another 40% increase to throughput compared to regular TP. + +Nevertheless, since the weights of the multi-modal encoder are replicated across each TP rank, +there will be a minor increase in memory consumption and may cause OOM if you can barely fit the model already. + +You can enable batch-level DP by setting `mm_encoder_tp_mode="data"`, for example: + +```python +from vllm import LLM + +llm = LLM( + model="Qwen/Qwen2.5-VL-72B-Instruct", + tensor_parallel_size=4, + # When mm_encoder_tp_mode="data", + # the vision encoder uses TP=4 (not DP=1) to shard the input data, + # so the TP size becomes the effective DP size. + # Note that this is independent of the DP size for language decoder which is used in expert parallel setting. + mm_encoder_tp_mode="data", + # The language decoder uses TP=4 to shard the weights regardless + # of the setting of mm_encoder_tp_mode +) +``` + +!!! important + Batch-level DP is not to be confused with API request-level DP + (which is instead controlled by `data_parallel_size`). + +Batch-level DP needs to be implemented on a per-model basis, +and enabled by setting `supports_encoder_tp_data = True` in the model class. +Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to use this feature. + +Known supported models: + +- Llama4 () +- MiniCPM-V-4 () +- Qwen2.5-VL () +- Step3 () + ## Input Processing ### Parallel Processing @@ -149,21 +199,41 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2 !!! note API server scale-out is only available for online inference. +!!! warning + By default, 8 CPU threads are used in each API server to load media items (e.g. images) + from request data. 
+ + If you apply API server scale-out, consider adjusting `VLLM_MEDIA_LOADING_THREAD_COUNT` + to avoid CPU resource exhaustion. + !!! note - [Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled + API server scale-out disables [multi-modal IPC caching](#ipc-caching) because it requires a one-to-one correspondence between API and engine core processes. + This does not impact [multi-modal processor caching](#processor-caching). + ## Multi-Modal Caching -### Processor Cache - -By default, the multi-modal processor cache is enabled to avoid repeatedly processing -the same multi-modal inputs via Hugging Face `AutoProcessor`, +Multi-modal caching avoids repeated transfer or processing of the same multi-modal data, which commonly occurs in multi-turn conversations. -You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` -(default 4 GiB per API process + 4 GiB per engine core process). -If you do not benefit much from the cache, you can disable it completely via `mm_processor_cache_gb=0`. +### Processor Caching + +Multi-modal processor caching is automatically enabled +to avoid repeatedly processing the same multi-modal inputs in `BaseMultiModalProcessor`. + +### IPC Caching + +Multi-modal IPC caching is automatically enabled when +there is a one-to-one correspondence between API (`P0`) and engine core (`P1`) processes, +to avoid repeatedly transferring the same multi-modal inputs between them. + +### Configuration + +You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB). + +If you do not benefit much from the cache, you can disable both IPC +and processor caching completely via `mm_processor_cache_gb=0`.
Examples: @@ -176,3 +246,16 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", mm_processor_cache_gb=0) ``` + +### Cache Placement + +Based on the configuration, the content of the multi-modal caches on `P0` and `P1` are as follows: + +| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. Memory | +|-------------------|-------------|------------|------------|-------------| +| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` | +| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` | +| ❌ | ❌ | N/A | N/A | `0` | + +K: Stores the hashes of multi-modal items +V: Stores the processed tensor data of multi-modal items diff --git a/docs/configuration/tpu.md b/docs/configuration/tpu.md index a93435ed71b50..e456077e04958 100644 --- a/docs/configuration/tpu.md +++ b/docs/configuration/tpu.md @@ -45,32 +45,32 @@ This initial compilation time ranges significantly and is impacted by many of th ### Optimize based on your data -#### max model len vs. most model len +#### max-model-len vs. most-model-len ![most_model_len](../assets/design/tpu/most_model_len.png) -If most of your requests are shorter than the maximum model length but you still need to accommodate occasional longer requests, setting a high maximum model length can negatively impact performance. In these cases, you can try introducing most model len by specifying the `VLLM_TPU_MOST_MODEL_LEN` environment variable. +If most of your requests are shorter than the maximum model length but you still need to accommodate occasional longer requests, setting a high maximum model length can negatively impact performance. In these cases, you can try introducing most-model-len by specifying the `VLLM_TPU_MOST_MODEL_LEN` environment variable. For example, 1% requests are 32k length and 99% requests are 2k length. You can pass 32k into `--max-model-len 32768` and use `VLLM_TPU_MOST_MODEL_LEN=2048`. 
-The requests get subdivided into max-model-len and most-model-len categories, for the latter category, we can gain better performance since the server can process more requests at a time. +The requests get subdivided into max-model-len and most-model-len categories, for the latter category, you can gain better performance since the server can process more requests at a time. #### Padding -For online serving with latency requirements, consider switching to bucket padding by setting the `VLLM_TPU_BUCKET_PADDING_GAP` environment variable. Because of the layout of the TPU, try using increments of 128: 128, 256, etc. +For online serving with latency requirements, consider switching to bucket padding by setting the `VLLM_TPU_BUCKET_PADDING_GAP` environment variable. Because of the layout of the TPU, try using increments of 128 (e.g., 128, 256, etc.) -The server pads the requests into fixed lengths before sending them to the model to avoid recompilation. To read more about tpu padding, see [here](https://cloud.google.com/tpu/docs/performance-guide#xla-efficiencies). Currently, there are 2 ways to pad the requests: +The server pads the requests into fixed lengths before sending them to the model to avoid recompilation. To read more about TPU padding, see [here](https://cloud.google.com/tpu/docs/performance-guide#xla-efficiencies). Currently, there are 2 ways to pad the requests: -1) the default exponential padding (pad to the nearest power of 2) -2) bucket padding (pad to the nearest linearly increasing bucket). +1. the default exponential padding (pad to the nearest power of 2) +2. bucket padding (pad to the nearest linearly increasing bucket). When using bucket padding, the buckets start from 16, end at max_model_len, and increment by `VLLM_TPU_BUCKET_PADDING_GAP`. For example, max_model_len=512, padding_gap=64, the buckets will be [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]. 
-The fewer tokens we pad, the less unnecessary computation TPU does, the better performance we can get. For example, if num_tokens=300, with exponential padding, we pad to 512, with the bucket_padding above, we pad to 320. +The fewer tokens you pad, the less unnecessary computation TPU does, the better performance you can get. For example, if num_tokens=300, with exponential padding, you pad to 512, with the bucket_padding above, you pad to 320. -However, you need to be careful to choose the padding gap. If the gap is too small, it means the number of buckets is large, leading to increased warmup (precompile) time and higher memory to store the compiled graph. Too many compilaed graphs may lead to HBM OOM. Conversely, an overly large gap yields no performance improvement compared to the default exponential padding. +However, you need to be careful to choose the padding gap. If the gap is too small, it means the number of buckets is large, leading to increased warmup (precompile) time and higher memory to store the compiled graph. Too many compiled graphs may lead to HBM OOM. Conversely, an overly large gap yields no performance improvement compared to the default exponential padding. 
#### Quantization diff --git a/docs/contributing/model/multimodal.md b/docs/contributing/model/multimodal.md index 64a48be32645a..76d0f067fd452 100644 --- a/docs/contributing/model/multimodal.md +++ b/docs/contributing/model/multimodal.md @@ -629,7 +629,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -778,7 +778,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() bos_token_id = hf_config.bos_token_id diff --git a/docs/deployment/frameworks/anything-llm.md b/docs/deployment/frameworks/anything-llm.md index e62a33b2085ca..0b41e73b030cc 100644 --- a/docs/deployment/frameworks/anything-llm.md +++ b/docs/deployment/frameworks/anything-llm.md @@ -18,7 +18,7 @@ vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096 - Download and install [Anything LLM desktop](https://anythingllm.com/desktop). 
-- On the bottom left of open settings, AI Prooviders --> LLM: +- On the bottom left of open settings, AI Providers --> LLM: - LLM Provider: Generic OpenAI - Base URL: http://{vllm server host}:{vllm server port}/v1 - Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ` diff --git a/docs/deployment/frameworks/dstack.md b/docs/deployment/frameworks/dstack.md index 23dc58c974ed8..fe4d87f78f2aa 100644 --- a/docs/deployment/frameworks/dstack.md +++ b/docs/deployment/frameworks/dstack.md @@ -9,7 +9,7 @@ vLLM can be run on a cloud based GPU machine with [dstack](https://dstack.ai/), To install dstack client, run: ```bash -pip install "dstack[all] +pip install dstack[all] dstack server ``` diff --git a/docs/design/fused_moe_modular_kernel.md b/docs/design/fused_moe_modular_kernel.md index 3ef1232051b07..202e9c1caf113 100644 --- a/docs/design/fused_moe_modular_kernel.md +++ b/docs/design/fused_moe_modular_kernel.md @@ -133,7 +133,7 @@ class FusedMoEModularKernel: Typically a FusedMoEPrepareAndFinalize type is backed by an All2All Dispatch & Combine implementation / kernel. For example, * PplxPrepareAndFinalize type is backed by Pplx All2All kernels, -* DeepEPHTPrepareAndFinalize type is backed by DeepEP High-Throughtput All2All kernels, and +* DeepEPHTPrepareAndFinalize type is backed by DeepEP High-Throughput All2All kernels, and * DeepEPLLPrepareAndFinalize type is backed by DeepEP Low-Latency All2All kernels. #### Step 1: Add an All2All manager @@ -175,11 +175,19 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking ### FusedMoEModularKernel Initialization -`FusedMoEMethodBase` class has 2 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. They are, +`FusedMoEMethodBase` class has 3 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. 
They are, +* maybe_make_prepare_finalize, * select_gemm_impl, and * init_prepare_finalize +#### maybe_make_prepare_finalize + +The `maybe_make_prepare_finalize` method is responsible for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case. +Please refer to the implementations in, + +* `ModelOptNvFp4FusedMoE` + #### select_gemm_impl The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEPermuteExpertsUnpermute` object. @@ -190,7 +198,7 @@ Please refer to the implementations in, * `CompressedTensorsW8A8Fp8MoECutlassMethod` * `Fp8MoEMethod` * `ModelOptNvFp4FusedMoE` -dervied classes. +derived classes. #### init_prepare_finalize @@ -218,7 +226,7 @@ Doing this will add the new implementation to the test suite. The unit test file [test_modular_kernel_combinations.py](gh-file:tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script. Example: `python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts` -As a side-effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked +As a side effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked with incompatible types, the script will error. 
### How To Profile diff --git a/docs/design/hybrid_kv_cache_manager.md b/docs/design/hybrid_kv_cache_manager.md new file mode 100644 index 0000000000000..8f17b473adc08 --- /dev/null +++ b/docs/design/hybrid_kv_cache_manager.md @@ -0,0 +1,245 @@ +# Hybrid KV Cache Manager + +!!! warning + This document was written based on commit [458e74](https://github.com/vllm-project/vllm/commit/458e74eb907f96069e6d8a4f3c9f457001fef2ea). This feature is still in its early stage and things may change. + +## What is a hybrid model? + +Many recent "hybrid" LLMs combine multiple attention types within one model. For example: + +1. Sliding window attention (sw) + full attention (full): gpt-oss, Gemma 2/3, Ministral, cohere, etc. +2. Mamba + full: Bamba, Jamba, Minimax, etc. +3. Local chunked attention + full: Llama4 + +To serve these models efficiently, our [KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] must: + +1. Allocate different slots to different layer type, for example: + - Full attention layers: reserve slots for **all** tokens. + - Sliding window layers: reserve slots only for the most recent **`sliding_window_size`** tokens. +2. Support layer-specific prefix-cache rules, for example: + - Full attention: a cache hit prefix requires **all** tokens remain in the KV cache. + - Sliding window: a cache hit prefix only requires the last **`sliding_window_size`** tokens remain in the KV cache. + +## Definitions + +1. **kv hidden size**: The number of bytes to store one token's KV cache for a single layer. +2. **block**: the memory reserved for kv cache are divided into multiple *blocks* with the same *page size* (defined below) +3. **block size**: number of tokens inside a block +4. **page size**: the physical memory size of a block, defined as: + + $$ + \text{num_layers} \times \text{block_size} \times \text{kv_hidden_size} + $$ + + `num_layers` doesn't mean the total number of layers in the model. The exact number depends on the context in this doc. + + !!! 
note + This is different from `KVCacheSpec.page_size_bytes` in the code, which is defined as: + + $$ + \text{block_size} \times \text{kv_hidden_size} + $$ + +## Allocation + +### High level idea + +We use a single memory pool for all layer types. The memory pool is split into multiple blocks with the same page size. [KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] allocates different numbers of blocks to different layers according to their attention types. + +The core challenge is ensuring every layer type uses the same **page size**. For full-attention-only models, the page size is straightforward, defined as: + +$$ +\text{page_size} = \text{block_size} \times \text{num_hidden_layers} \times \text{kv_hidden_size} +$$ + +However, in hybrid models, `num_hidden_layers` varies by attention type, which would normally produce mismatched page sizes. The cases below show how we unify them. + +### Case 1: toy model + +Let's start with a toy example: a model has 1 full attention layer and 3 sliding window attention layers. All layers have the same `kv_hidden_size`. + +We let each block hold `block_size` tokens for one layer, so: + +$$ +\text{page_size} = \text{kv_hidden_size} \times \text{block_size} +$$ + +[KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] allocates a different number of blocks to each layer. + +This case is only a toy example. For real models, please refer to the following cases. + +### Case 2: same `kv_hidden_size` and a regular pattern + +When the model has more layers, e.g., 20 sliding window attention layers and 10 full attention layers with the same `kv_hidden_size`, calling the allocator once per layer (30 calls) is OK but becomes inefficient. As a solution, we group the allocation of layers that need the same number of blocks to reduce the number of calls. + +The grouping is feasible because there is usually a beautiful ratio between the number of different types of layers.
For example: + +- Gemma-2: 1 sw : 1 full +- Llama 4: 3 local : 1 full + +Our example can be regarded as 2 sw : 1 full. We can allocate blocks as if there are 2 sw and 1 full in the model, and repeat the result 10 times to generate the `block_ids` for the 30 layers. The page size becomes: + +$$ +10 \times \text{kv_hidden_size} \times \text{block_size} +$$ + +Assume `block_size` 16, sliding window size 32, request length 112, then for the above example model, we need to allocate 11 blocks (0-6 for full, 7-8 for sw group 1, 9-10 for sw group 2). + +![Allocation Result](../assets/design/hybrid_kv_cache_manager/basic_grouping_example.png) + +Here, "/" denotes no block needed (sliding‑window layers don't need slots for early tokens). + +See the formal definition below. The layers are divided into multiple *KV Cache Groups* so that there is: + +1. **Identical attention type inside each group**: Each group only contains layers with the same attention type and thus needs the same number of blocks for a given request. This enables layers in the same group to share the same block ids without memory waste. +2. **Identical page size across groups**: Because our memory pool only has one page size. + +Our example model is divided into 3 KV cache groups: + +- Group 0: 10 full attention layers (full.0 - full.9) +- Group 1: 10 sliding window attention layers (sw.0 - sw.9) +- Group 2: 10 sliding window attention layers (sw.10 - sw.19) + +Obviously, it satisfies rule 1. For rule 2, all 3 groups have + +$$ +10 \times \text{kv_hidden_size} \times \text{block_size} +$$ + +as their page size. + +### Case 3: same `kv_hidden_size` and no regular pattern + +Unfortunately, not all models have such a beautiful ratio, and the approach in Case 2 will produce too many small groups. For example, Gemma-3-27b has 52 sliding window attention layers and 10 full attention layers. With the constraints in case 2, it would be 26 sliding window groups and 5 full attention groups, each containing 2 layers.
The allocation is still inefficient. To reduce the number of kv cache groups, we group layers using the smallest layer count among all attention types. For example, min(52, 10)=10 layers per group in Gemma-3-27b. Then the grouping result is: + +- Group 0: 10 full attention layers (full.0 - full.9) +- Group 1: 10 sliding window attention layers (sw.0 - sw.9) +- Group 2: 10 sliding window attention layers (sw.10 - sw.19) +- ... +- Group 6: 10 sliding window attention layers (sw.40 - sw.49) +- Group 7: 2 sliding window attention layers (sw.50 - sw.51) and 8 padding layers + +We will update this algorithm if this heuristic leads to a bad result when a new model comes out (e.g., 20 full + 30 sw, the group size should be 10 instead of 20). + +This case happens in Gemma-3 series models, and models in case 2 but with eagle speculative decoding which introduce one full attention layer. The solution has some memory waste and is not perfect. Please report any cases where padding overhead becomes unacceptable so we can refine the algorithm. + +### Case 4: different `kv_hidden_size` (mainly hybrid mamba models) + +Some architectures (e.g., Bamba, Jamba, Minimax) interleave standard attention layers with Mamba layers, where each Mamba layer's state size per token can be much larger than the attention layers' `kv_hidden_size`. Because we only support a single page size across all groups, we must reconcile these differing hidden sizes. + +The current algorithm is: + +1. Increase the `block_size` of attention layers until + $$ + \text{block_size} \times \text{kv_hidden_size}_{\text{att}} \ge \text{state_size}_{\text{mamba}} + $$ +2. Pad the mamba state per layer to + $$ + \text{block_size} \times \text{kv_hidden_size}_{\text{att}} + $$ +3. Apply the grouping strategy in case 3. + +!!! note + This can lead to more than 400 `block_size` for attention layers, which is too large. 
Another padding strategy is to increase `block_size` until + + $$ + \text{block_size} \times \text{kv_hidden_size}_{\text{att}} \times \text{num_attn_layers} \ge \text{state_size}_{\text{mamba}} + $$ + + This padding strategy is still a work in progress. + +### Case 5: KV sharing + +KV sharing refers to a layer using the KV cache of another layer, e.g., gemma-3n. +In these models, [KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] ignores all layers with kv sharing and only allocates KV cache for layers that need kv cache, and some patches are made in model runner to apply the allocation result to kv sharing layers. + +## Prefix caching + +For simplicity, we assume `block_size=1` in this section. + +### High level idea + +The block pool uses a dict similar to `tuple(block_hash, group_id) -> block` to cache the full blocks. That means the same tokens of different groups are cached and evicted independently. + +When a new request comes in, we check the cache hit prefix of each group, and return the intersection of these groups as the cached prefix of the request. See below for the detailed algorithm for checking the cache hit of one group & performing the intersection. + +### Case 0: full attention only models + +For full attention layers, blocks are allocated for all tokens in the request. For details on the underlying design, see [Prefix Caching](prefix_caching.md). + +To find the longest cache hit prefix of a request, we enumerate from left (the first block) to right (the last block), checking whether the block is cached, and exit on the first cache miss.
For example, we will return the first 7 tokens (0-6) as the cache hit prefix in the below example (blue blocks are cached): + +![Prefix Caching of Full Attention](../assets/design/hybrid_kv_cache_manager/full_attn.png) + +### Case 1: sliding window attention only models + +For sliding window attention layers, a naive implementation for memory allocation is to allocate `sliding_window_size` blocks and fill in the blocks in a round-robin way. But this naive implementation is not compatible with prefix caching so we didn't pick this design. In vLLM, we allocate different blocks for different tokens and free blocks that are outside the sliding window. + +For a new request, the cache hit prefix only requires the last `sliding_window_size - 1` tokens to be cached. +Let's say `sliding_window_size = 4` and `block_size = 1`, and the request is a 15-token prompt (blue blocks are cached): + +![Prefix Caching of Sliding Window Attention](../assets/design/hybrid_kv_cache_manager/sw_attn.png) + +There are 3 possible cache hit prefixes: + +- cache hit length 5, compute prefill with [2, 3, 4] → [5, 6, …, 14] +- cache hit length 6, compute prefill with [3, 4, 5] → [6, 7, …, 14] +- cache hit length 14, compute prefill with [11, 12, 13] → [14] (most efficient) + +We can check the cache hit from right to left, and early exit when we find a match. This is the opposite of full attention, where we check from left to right and early exit when the match fails. One potential drawback (compared to full attention) is that we end up iterating over the entire list of tokens when there's no match, which is a common case. This could potentially cause non-negligible overheads, but it is fine with full + swa, as discussed below. + +### Case 2: sliding window attention + full attention models + +The first problem is how to find the cache hit prefix. We need to "intersect" the cache hits of global and sliding window attention layers by: + +1.
Get the longest cache hit for full attention (scanning from left to right) +2. Get the longest cache hit for sliding window attention that is within that length. Implemented by checking cache hits from right to left starting from the cache hit length of full attention. + +It can be ensured that the resulting cache hit of sliding window attention layers is also a cache hit of full attention layers. This is more efficient than finding all possible prefixes of each group and doing the intersection, because our approach can exit early if there is no cache hit. + +The algorithm applies to models with exactly two attention types full attention + X, where X can be an arbitrary efficient attention algorithm like sliding window, llama 4 local attention, and mamba. It doesn't support models without full attention layers, and models with more than 2 types of attention. This is enough for most hybrid models at the moment of writing this doc. + +The second question is the cache eviction policy. For now, we use one LRU queue for all kv cache groups. The blocks are added to the LRU queue when freed, either because the request is finished or the block is out of the sliding window. + +### Case 3: mamba models + +The prefix caching support of the mamba model is work in progress. Once implemented, models with mamba layer + full attention layer can be supported via the full attention + X algorithm in case 2. + +## Implementation + +### Overview + +![Overview of Hybrid KV Cache Manager](../assets/design/hybrid_kv_cache_manager/overview.png) + +The `KVCacheManager` is organized into 3 layers: + +- **[KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager]**: The interface between the scheduler and kv cache management system. +- **[KVCacheCoordinator][vllm.v1.core.kv_cache_coordinator.KVCacheCoordinator]**: coordinate per-group SingleTypeKVCacheManagers to generate the allocation result of a request. 
Depending on the model's configuration, one of these coordinators is chosen: + - **[KVCacheCoordinatorNoPrefixCache][vllm.v1.core.kv_cache_coordinator.KVCacheCoordinatorNoPrefixCache]**: Used when prefix caching is disabled. + - **[UnitaryKVCacheCoordinator][vllm.v1.core.kv_cache_coordinator.UnitaryKVCacheCoordinator]**: If only one KV cache group. The prefix caching logic is simplified as no intersection is needed. + - **[HybridKVCacheCoordinator][vllm.v1.core.kv_cache_coordinator.HybridKVCacheCoordinator]**: Handles exactly two KV cache groups (must include one full‑attention group plus one other efficient‑attention group). Other cases are not implemented. You can disable prefix caching to use the KVCacheCoordinatorNoPrefixCache. +- **[SingleTypeKVCacheManager][vllm.v1.core.single_type_kv_cache_manager.SingleTypeKVCacheManager]**: Each instance manages allocation and prefix caching for one KV cache group, implementing the attention‑type–specific logic (e.g., full attention, sliding window, Mamba). + +The blue box in the above figure shows the case with 10 full attention layers and 20 sliding window attention layers, thus: + +- use `HybridKVCacheCoordinator` +- use 1 `FullAttentionManager` and 2 `SlidingWindowManager` for the 3 `KVCacheGroup`s. + +### Memory Layout + +For a model with n `KVCacheGroup`s, each with m layers, we allocate m buffers. Each buffer is shared by n layers, one from each group. + +The following figure is for a model with 10 full attention layers (full.0 - full.9) and 20 sliding window attention layers (sw.0-sw.19). It follows "case 2" in "Allocation" section and is divided into 3 groups: + +- Group 0: 10 full attention layers (full.0 - full.9) +- Group 1: 10 sliding window attention layers (sw.0 - sw.9) +- Group 2: 10 sliding window attention layers (sw.10 - sw.19) + +And for a request, we allocate 11 blocks with `block_id` 0-6 to group 0, 7-8 to group 1, and 9-10 to group 2. 
+ +With such an example, the physical memory is divided into 10 buffers (`KVCacheTensor` 0 - `KVCacheTensor` 9). Each buffer is shared by 3 layers (e.g., `KVCacheTensor` 0 is shared by full.0 from group 0, sw.0 from group 1, and sw.10 from group 2) and is divided into pieces with size `block_size * kv_hidden_size`. The KV cache of these 3 attention layers is saved to different pieces of the buffer based on the allocated `block_ids`: + +![Example Memory Layout](../assets/design/hybrid_kv_cache_manager/memory_layout.png) + +!!! note + One logical "block" is mapped to 10 pieces in the 10 buffers of the physical memory. diff --git a/docs/design/metrics.md b/docs/design/metrics.md index b01838883f31e..b24364247b3f8 100644 --- a/docs/design/metrics.md +++ b/docs/design/metrics.md @@ -565,7 +565,7 @@ model and then validate those tokens with the larger model. - `vllm:spec_decode_num_emitted_tokens_total` (Counter) There is a PR under review () to add "prompt lookup (ngram)" -seculative decoding to v1. Other techniques will follow. We should +speculative decoding to v1. Other techniques will follow. We should revisit the v0 metrics in this context. !!! note diff --git a/docs/design/multiprocessing.md b/docs/design/multiprocessing.md index 06ebd77258582..247072d1cb275 100644 --- a/docs/design/multiprocessing.md +++ b/docs/design/multiprocessing.md @@ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`. There are other miscellaneous places hard-coding the use of `spawn`: -- +- - Related PRs: diff --git a/docs/design/paged_attention.md b/docs/design/paged_attention.md index fb991a35caf30..d87b2a639df12 100644 --- a/docs/design/paged_attention.md +++ b/docs/design/paged_attention.md @@ -422,7 +422,7 @@ a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle a whole block of value tokens. And each `accs` in each thread contains 8 elements that accumulated at 8 different head positions.
For the thread 0, the `accs` variable will have 8 elements, which -are 0th, 32th … 224th elements of a value head that are accumulated +are 0th, 32nd … 224th elements of a value head that are accumulated from all assigned 8 tokens. ## LV diff --git a/docs/examples/README.md b/docs/examples/README.md index 34e4dfd408a20..3cf93027f4209 100644 --- a/docs/examples/README.md +++ b/docs/examples/README.md @@ -2,6 +2,6 @@ vLLM's examples are split into three categories: -- If you are using vLLM from within Python code, see [Offline Inference](./offline_inference/) -- If you are using vLLM from an HTTP application or client, see [Online Serving](./online_serving/) -- For examples of using some of vLLM's advanced features (e.g. LMCache or Tensorizer) which are not specific to either of the above use cases, see [Others](./others/) +- If you are using vLLM from within Python code, see [Offline Inference](./offline_inference) +- If you are using vLLM from an HTTP application or client, see [Online Serving](./online_serving) +- For examples of using some of vLLM's advanced features (e.g. 
LMCache or Tensorizer) which are not specific to either of the above use cases, see [Others](./others) diff --git a/docs/features/multimodal_inputs.md b/docs/features/multimodal_inputs.md index cdd32924b5668..9d51f9cf52f50 100644 --- a/docs/features/multimodal_inputs.md +++ b/docs/features/multimodal_inputs.md @@ -216,7 +216,7 @@ Instead of NumPy arrays, you can also pass `'torch.Tensor'` instances, as shown from vllm import LLM, SamplingParams from qwen_vl_utils import process_vision_info - model_path = "Qwen/Qwen2.5-VL-3B-Instruct/" + model_path = "Qwen/Qwen2.5-VL-3B-Instruct" video_path = "https://content.pexels.com/videos/free-videos.mp4" llm = LLM( diff --git a/docs/features/quantization/README.md b/docs/features/quantization/README.md index e18c128f30fc9..4605ba7781ed4 100644 --- a/docs/features/quantization/README.md +++ b/docs/features/quantization/README.md @@ -4,7 +4,6 @@ Quantization trades off model precision for smaller memory footprint, allowing l Contents: -- [Supported Hardware](supported_hardware.md) - [AutoAWQ](auto_awq.md) - [AutoRound](auto_round.md) - [BitsAndBytes](bnb.md) @@ -19,3 +18,50 @@ Contents: - [AMD Quark](quark.md) - [Quantized KV Cache](quantized_kvcache.md) - [TorchAO](torchao.md) + +## Supported Hardware + +The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM: + + + +| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | AWS Neuron | Google TPU | +|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|--------------| +| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | +| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | +| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | +| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ 
| ✅︎ | ❌ | +| BitBLAS | ✅︎ | ✅ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| BitBLAS (GPTQ) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | ❌ | + +- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. +- ✅︎ indicates that the quantization method is supported on the specified hardware. +- ❌ indicates that the quantization method is not supported on the specified hardware. + +!!! note + This compatibility chart is subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods. + + For the most up-to-date information on hardware support and quantization methods, please refer to or consult with the vLLM development team. diff --git a/docs/features/quantization/bitblas.md b/docs/features/quantization/bitblas.md index 6f53a448ee364..53b689ad53ff6 100644 --- a/docs/features/quantization/bitblas.md +++ b/docs/features/quantization/bitblas.md @@ -5,7 +5,7 @@ vLLM now supports [BitBLAS](https://github.com/microsoft/BitBLAS) for more effic !!! note Ensure your hardware supports the selected `dtype` (`torch.bfloat16` or `torch.float16`). Most recent NVIDIA GPUs support `float16`, while `bfloat16` is more common on newer architectures like Ampere or Hopper. - For details see [supported hardware](supported_hardware.md). + For details see [supported hardware](README.md#supported-hardware). Below are the steps to utilize BitBLAS with vLLM. 
diff --git a/docs/features/quantization/fp8.md b/docs/features/quantization/fp8.md index 0661933acd61f..834c03cbe05b0 100644 --- a/docs/features/quantization/fp8.md +++ b/docs/features/quantization/fp8.md @@ -79,7 +79,7 @@ Since simple RTN does not require data for weight quantization and the activatio Install `vllm` and `lm-evaluation-harness` for evaluation: ```bash -pip install vllm lm-eval==0.4.4 +pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] ``` Load and run the model in `vllm`: diff --git a/docs/features/quantization/inc.md b/docs/features/quantization/inc.md index 13b151bc7f380..5e86e9388f328 100644 --- a/docs/features/quantization/inc.md +++ b/docs/features/quantization/inc.md @@ -7,7 +7,7 @@ Intel Gaudi supports quantization of various modules and functions, including, b [Supported Modules\\Supported Functions\\Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules). !!! note - Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vllm-hpu-extention](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. + Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vLLM HPU extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. !!! note `QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options). 
diff --git a/docs/features/quantization/int4.md b/docs/features/quantization/int4.md index 127e403989944..d6fdac7b07f7f 100644 --- a/docs/features/quantization/int4.md +++ b/docs/features/quantization/int4.md @@ -18,7 +18,7 @@ pip install llmcompressor Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: ```bash -pip install vllm lm-eval==0.4.4 +pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] ``` ## Quantization Process diff --git a/docs/features/quantization/int8.md b/docs/features/quantization/int8.md index 45fae58a64868..247d0cbdd3f14 100644 --- a/docs/features/quantization/int8.md +++ b/docs/features/quantization/int8.md @@ -19,7 +19,7 @@ pip install llmcompressor Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: ```bash -pip install vllm lm-eval==0.4.4 +pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] ``` ## Quantization Process diff --git a/docs/features/quantization/quark.md b/docs/features/quantization/quark.md index e8ed2155375d4..047cc8382445b 100644 --- a/docs/features/quantization/quark.md +++ b/docs/features/quantization/quark.md @@ -20,7 +20,7 @@ for more installation details. 
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: ```bash -pip install vllm lm-eval==0.4.4 +pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] ``` ## Quantization Process diff --git a/docs/features/quantization/supported_hardware.md b/docs/features/quantization/supported_hardware.md deleted file mode 100644 index f53e69ecc6115..0000000000000 --- a/docs/features/quantization/supported_hardware.md +++ /dev/null @@ -1,33 +0,0 @@ -# Supported Hardware - -The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM: - - - -| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | AWS Neuron | Google TPU | -|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|--------------| -| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | -| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | -| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | -| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ❌ | -| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | -| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | ❌ | - -- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. -- ✅︎ indicates that the quantization method is supported on the specified hardware. -- ❌ indicates that the quantization method is not supported on the specified hardware. - -!!! 
note - This compatibility chart is subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods. - - For the most up-to-date information on hardware support and quantization methods, please refer to or consult with the vLLM development team. diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index 37d502ef9ce0a..afc605a504b3d 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -284,6 +284,14 @@ Supported models: Flags: `--tool-call-parser deepseek_v3 --chat-template {see_above}` +### DeepSeek-V3.1 Models (`deepseek_v31`) + +Supported models: + +* `deepseek-ai/DeepSeek-V3.1` (use with ) + +Flags: `--tool-call-parser deepseek_v31 --chat-template {see_above}` + ### Kimi-K2 Models (`kimi_k2`) Supported models: diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index 7a34d47d8e494..e76ec35e1edcb 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -170,7 +170,7 @@ This value is 4GB by default. Larger space can support more concurrent requests, First of all, please make sure the thread-binding and KV cache space are properly set and take effect. You can check the thread-binding by running a vLLM benchmark and observing CPU cores usage via `htop`. -Inference batch size is a important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: +Inference batch size is an important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. 
Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: - `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impacts on the first token performance. The default value is set as: - Offline Inference: `4096 * world_size` @@ -179,7 +179,7 @@ Inference batch size is a important parameter for the performance. Larger batch - Offline Inference: `256 * world_size` - Online Serving: `128 * world_size` -vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more detials of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use TP and PP togther if there are enough CPU sockets and memory nodes. +vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommended to use TP and PP together if there are enough CPU sockets and memory nodes. ### Which quantization configs does vLLM CPU support? @@ -190,6 +190,6 @@ vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage mu ### (x86 only) What is the purpose of `VLLM_CPU_MOE_PREPACK` and `VLLM_CPU_SGL_KERNEL`? -- Both of them requires `amx` CPU flag. +- Both of them require `amx` CPU flag. - `VLLM_CPU_MOE_PREPACK` can provides better performance for MoE models - `VLLM_CPU_SGL_KERNEL` can provides better performance for MoE models and small-batch scenarios.
diff --git a/docs/getting_started/installation/intel_gaudi.md b/docs/getting_started/installation/intel_gaudi.md index 61b2b02aa10ba..ff912efec9ca8 100644 --- a/docs/getting_started/installation/intel_gaudi.md +++ b/docs/getting_started/installation/intel_gaudi.md @@ -261,13 +261,13 @@ Lower value corresponds to less usable graph memory reserved for prefill stage, User can also configure the strategy for capturing HPU Graphs for prompt and decode stages separately. Strategy affects the order of capturing graphs. There are two strategies implemented: -- `max_bs` - graph capture queue will sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. `(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode +- `max_bs` - graph capture queue will be sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. `(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode - `min_tokens` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (`batch_size*sequence_length`), default strategy for prompt When there's large amount of requests pending, vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. When a request is finished, decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size to its previous state. This means that in a full load scenario, decode batch size is often at its maximum, which makes large batch size HPU Graphs crucial to capture, as reflected by `max_bs` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in `min_tokens` strategy. !!! 
note - `VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it will attempt do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. + `VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it will attempt to do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. 
Each described step is logged by vLLM server, as follows (negative values correspond to memory being released): diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md index f833807666460..2af26626d207d 100644 --- a/docs/getting_started/quickstart.md +++ b/docs/getting_started/quickstart.md @@ -8,7 +8,7 @@ This guide will help you quickly get started with vLLM to perform: ## Prerequisites - OS: Linux -- Python: 3.9 -- 3.12 +- Python: 3.9 -- 3.13 ## Installation diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py index ed5d3b0092ae7..051a2d904406d 100644 --- a/docs/mkdocs/hooks/generate_argparse.py +++ b/docs/mkdocs/hooks/generate_argparse.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib import logging import sys from argparse import SUPPRESS, HelpFormatter @@ -7,25 +8,52 @@ from pathlib import Path from typing import Literal from unittest.mock import MagicMock, patch +from pydantic_core import core_schema + +logger = logging.getLogger("mkdocs") + ROOT_DIR = Path(__file__).parent.parent.parent.parent ARGPARSE_DOC_DIR = ROOT_DIR / "docs/argparse" sys.path.insert(0, str(ROOT_DIR)) -sys.modules["aiohttp"] = MagicMock() -sys.modules["blake3"] = MagicMock() sys.modules["vllm._C"] = MagicMock() -from vllm.benchmarks import latency # noqa: E402 -from vllm.benchmarks import serve # noqa: E402 -from vllm.benchmarks import throughput # noqa: E402 -from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs # noqa: E402 -from vllm.entrypoints.cli.openai import ChatCommand # noqa: E402 -from vllm.entrypoints.cli.openai import CompleteCommand # noqa: E402 -from vllm.entrypoints.openai import cli_args # noqa: E402 -from vllm.entrypoints.openai import run_batch # noqa: E402 -from vllm.utils import FlexibleArgumentParser # noqa: E402 -logger = logging.getLogger("mkdocs") +class PydanticMagicMock(MagicMock): + 
"""`MagicMock` that's able to generate pydantic-core schemas.""" + + def __get_pydantic_core_schema__(self, source_type, handler): + return core_schema.any_schema() + + +def auto_mock(module, attr, max_mocks=50): + """Function that automatically mocks missing modules during imports.""" + logger.info("Importing %s from %s", attr, module) + for _ in range(max_mocks): + try: + # First treat attr as an attr, then as a submodule + return getattr(importlib.import_module(module), attr, + importlib.import_module(f"{module}.{attr}")) + except importlib.metadata.PackageNotFoundError as e: + raise e + except ModuleNotFoundError as e: + logger.info("Mocking %s for argparse doc generation", e.name) + sys.modules[e.name] = PydanticMagicMock() + + raise ImportError( + f"Failed to import {module}.{attr} after mocking {max_mocks} imports") + + +latency = auto_mock("vllm.benchmarks", "latency") +serve = auto_mock("vllm.benchmarks", "serve") +throughput = auto_mock("vllm.benchmarks", "throughput") +AsyncEngineArgs = auto_mock("vllm.engine.arg_utils", "AsyncEngineArgs") +EngineArgs = auto_mock("vllm.engine.arg_utils", "EngineArgs") +ChatCommand = auto_mock("vllm.entrypoints.cli.openai", "ChatCommand") +CompleteCommand = auto_mock("vllm.entrypoints.cli.openai", "CompleteCommand") +cli_args = auto_mock("vllm.entrypoints.openai", "cli_args") +run_batch = auto_mock("vllm.entrypoints.openai", "run_batch") +FlexibleArgumentParser = auto_mock("vllm.utils", "FlexibleArgumentParser") class MarkdownFormatter(HelpFormatter): diff --git a/docs/mkdocs/hooks/generate_examples.py b/docs/mkdocs/hooks/generate_examples.py index 6b4c5b31075f7..881df791698e2 100644 --- a/docs/mkdocs/hooks/generate_examples.py +++ b/docs/mkdocs/hooks/generate_examples.py @@ -24,7 +24,6 @@ def fix_case(text: str) -> str: "llm": "LLM", "mae": "MAE", "tpu": "TPU", - "aqlm": "AQLM", "gguf": "GGUF", "lora": "LoRA", "rlhf": "RLHF", @@ -71,6 +70,10 @@ class Example: self.other_files = self.determine_other_files() self.title = 
self.determine_title() + @property + def is_code(self) -> bool: + return self.main_file.suffix != ".md" + def determine_main_file(self) -> Path: """ Determines the main file in the given path. @@ -102,6 +105,12 @@ class Example: return [file for file in self.path.rglob("*") if is_other_file(file)] def determine_title(self) -> str: + if not self.is_code: + with open(self.main_file) as f: + first_line = f.readline().strip() + match = re.match(r'^#\s+(?P<title>.+)$', first_line) + if match: + return match.group('title') return fix_case(self.path.stem.replace("_", " ").title()) def generate(self) -> str: @@ -111,11 +120,13 @@ class Example: # Use long code fence to avoid issues with # included files containing code fences too code_fence = "``````" - is_code = self.main_file.suffix != ".md" - if is_code: + # Skip the title from md snippets as it's been included above + start_line = 2 + if self.is_code: content += f"{code_fence}{self.main_file.suffix[1:]}\n" - content += f'--8<-- "{self.main_file}"\n' - if is_code: + start_line = 1 + content += f'--8<-- "{self.main_file}:{start_line}"\n' + if self.is_code: content += f"{code_fence}\n" content += "\n" diff --git a/docs/mkdocs/javascript/mathjax.js b/docs/mkdocs/javascript/mathjax.js new file mode 100644 index 0000000000000..5da0d443578c4 --- /dev/null +++ b/docs/mkdocs/javascript/mathjax.js @@ -0,0 +1,20 @@ +// Enables MathJax rendering +window.MathJax = { + tex: { + inlineMath: [["\\(", "\\)"]], + displayMath: [["\\[", "\\]"]], + processEscapes: true, + processEnvironments: true + }, + options: { + ignoreHtmlClass: ".*|", + processHtmlClass: "arithmatex" + } +}; + +document$.subscribe(() => { + MathJax.startup.output.clearCache() + MathJax.typesetClear() + MathJax.texReset() + MathJax.typesetPromise() +}) diff --git a/docs/models/generative_models.md b/docs/models/generative_models.md index a64ecd31ebaef..d02522a6657de 100644 --- a/docs/models/generative_models.md +++ b/docs/models/generative_models.md @@ -19,7 +19,7 @@ Run a
model in generation mode via the option `--runner generate`. ## Offline Inference The [LLM][vllm.LLM] class provides various methods for offline inference. -See [configuration](../api/summary.md#configuration) for a list of options when initializing the model. +See [configuration](../api/README.md#configuration) for a list of options when initializing the model. ### `LLM.generate` diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index 39f209d0eb7ed..fbb5f6f6dd171 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -81,7 +81,7 @@ which takes priority over both the model's and Sentence Transformers's defaults. ## Offline Inference The [LLM][vllm.LLM] class provides various methods for offline inference. -See [configuration](../api/summary.md#configuration) for a list of options when initializing the model. +See [configuration](../api/README.md#configuration) for a list of options when initializing the model. ### `LLM.embed` @@ -205,12 +205,12 @@ Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, if `is_matryoshka` is `True` in `config.json,` it is allowed to change the output to arbitrary dimensions. Using `matryoshka_dimensions` can control the allowed output dimensions. -For models that support Matryoshka Embeddings but not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}`, `hf_overrides={"matryoshka_dimensions": [<allowed output dimensions>]}` (offline) or `--hf_overrides '{"is_matryoshka": true}'`, `--hf_overrides '{"matryoshka_dimensions": [<allowed output dimensions>]}'`(online). 
+For models that support Matryoshka Embeddings but not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}`, `hf_overrides={"matryoshka_dimensions": [<allowed output dimensions>]}` (offline) or `--hf-overrides '{"is_matryoshka": true}'`, `--hf-overrides '{"matryoshka_dimensions": [<allowed output dimensions>]}'`(online). Here is an example to serve a model with Matryoshka Embeddings enabled. ```text -vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"matryoshka_dimensions":[256]}' +vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf-overrides '{"matryoshka_dimensions":[256]}' ``` ### Offline Inference diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index a24fa4bcce333..35a5fa0c2e42f 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -328,10 +328,11 @@ th { | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ | | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ | -| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | | +| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | ✅︎ | | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | +| `MBartForConditionalGeneration` | mBART | `facebook/mbart-large-en-ro`, `facebook/mbart-large-50`, etc. | | | | | `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, etc. 
| ✅︎ | ✅︎ | ✅︎ | +| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R, Command-A | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, `CohereLabs/c4ai-command-a-03-2025`, `CohereLabs/command-a-reasoning-08-2025`, etc. | ✅︎ | ✅︎ | ✅︎ | | `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ | | `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | | ✅︎ | ✅︎ | @@ -362,7 +363,7 @@ th { | `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ | -| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ | | `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ | | `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | | ✅︎ | | `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | | ✅︎ | @@ -372,6 +373,7 @@ th { | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ | | `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. 
| ✅︎ | ✅︎ | ✅︎ | +| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | ✅︎ | | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ | | `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ | @@ -383,8 +385,8 @@ th { | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ | | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | | `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ | -| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ | +| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | | `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ | | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ | @@ -399,6 +401,7 @@ th { | `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. 
| ✅︎ | ✅︎ | ✅︎ | | `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ | | `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ | | `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -418,6 +421,9 @@ Some models are supported only via the [Transformers backend](#transformers). Th !!! note Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. +!!! note + Some mBART models' config files do not have an `architecture` defined. Therefore, you need to use `--hf-overrides '{"architectures": ["MBartForConditionalGeneration"]}'` to explicitly specify the use of the `MBartForConditionalGeneration` architecture. + ### Pooling Models See [this page](./pooling_models.md) for more information on how to use pooling models. @@ -432,17 +438,17 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | | -| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | ✅︎ | -| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | -| `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | | -| `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. 
| | | | -| `ModernBertModel`<sup>C</sup> | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | | -| `NomicBertModel`<sup>C</sup> | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | | +| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | ✅︎ | +| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ | +| `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | ✅︎ | +| `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | ✅︎ | +| `ModernBertModel`<sup>C</sup> | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | ✅︎ | +| `NomicBertModel`<sup>C</sup> | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | ✅︎ | | `LlamaModel`<sup>C</sup>, `LlamaForCausalLM`<sup>C</sup>, `MistralModel`<sup>C</sup>, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2Model`<sup>C</sup>, `Qwen2ForCausalLM`<sup>C</sup> | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3Model`<sup>C</sup>, `Qwen3ForCausalLM`<sup>C</sup> | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | | +| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | ✅︎ | | `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* | <sup>C</sup> Automatically converted into an embedding model via `--convert embed`. 
([details](./pooling_models.md#model-conversion)) @@ -472,7 +478,7 @@ These models primarily support the [`LLM.classify`](./pooling_models.md#llmclass | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | | +| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ | | `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* | @@ -489,12 +495,12 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | | +| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ | | `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | -| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | | -| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. 
| | | | +| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ | +| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | ✅︎ | | `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* | <sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion)) @@ -609,6 +615,8 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ | | `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ | | `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ | +| `DonutForConditionalGeneration`<sup>^</sup> | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | | +| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ | | `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | @@ -620,9 +628,9 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ | | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. 
| ✅︎ | | ✅︎ | | `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ | -| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ | +| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ | | `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ | | `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ | | `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. 
| | ✅︎ | ✅︎ | @@ -630,13 +638,14 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ | | `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ | | `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, etc. | ✅︎ | | ✅︎ | +| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, `openbmb/MiniCPM-V-4_5`, etc. | ✅︎ | | ✅︎ | | `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ | | `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | | | `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ | | `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ | | `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ | +| `Ovis2_5` | Ovis2.5 | T + I<sup>+</sup> + V | `AIDC-AI/Ovis2.5-9B`, etc. 
| | | ✅︎ | | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ | | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ | | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -647,6 +656,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ | +| `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | ✅︎ | | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | | `Step3VLForConditionalGeneration` | Step3-VL | T + I<sup>+</sup> | `stepfun-ai/step3` | | ✅︎ | ✅︎ | @@ -692,7 +702,7 @@ Some models are supported only via the [Transformers backend](#transformers). Th - There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups. !!! 
note - Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently. + For `InternVLChatModel`, only InternVL2.5 with a Qwen2.5 text backbone (`OpenGVLab/InternVL2.5-1B`, etc.), InternVL3, and InternVL3.5 currently support video inputs. !!! note To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 54af970ea842d..20234e7611333 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -107,15 +107,16 @@ to enable simultaneous generation and embedding using the same engine instance i #### Mamba Models Models using selective state-space mechanisms instead of standard transformer attention are supported. -Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`. +Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. +Please note that prefix caching is not yet supported for these models. Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, -`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). Please note that -these models currently require disabling prefix caching and using the FlashInfer attention backend in V1. +`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM`, `GraniteMoeHybridForCausalLM`, and `JambaForCausalLM`). +Please note that prefix caching is not yet supported for these models. Hybrid models with mechanisms different to Mamba are also supported (e.g, `MiniMaxText01ForCausalLM`, `MiniMaxM1ForCausalLM`). 
-Please note that these models currently require disabling prefix caching, enforcing eager mode, and using the FlashInfer -attention backend in V1. +Please note that prefix caching is not yet supported for these models. +It is also necessary to enforce eager mode for these models in V1. #### Encoder-Decoder Models @@ -154,16 +155,19 @@ differences compared to V0: ##### Logprobs Calculation -Logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e. +By default, logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e. before applying any logits post-processing such as temperature scaling or penalty adjustments). As a result, the returned logprobs do not reflect the final adjusted probabilities used during sampling. -Support for logprobs with post-sampling adjustments is in progress and will be added in future updates. +You can adjust this behavior by setting the `--logprobs-mode` flag. +Four modes are supported: `raw_logprobs` (default), `processed_logprobs`, `raw_logits`, `processed_logits`. +Raw means the values before applying any logits processors, such as the bad-words processor. +Processed means the values after applying all processors, including temperature and top_k/top_p. ##### Prompt Logprobs with Prefix Caching -Currently prompt logprobs are only supported when prefix caching is turned off via `--no-enable-prefix-caching`. In a future release, prompt logprobs will be compatible with prefix caching, but a recomputation will be triggered to recover the full prompt logprobs even upon a prefix cache hit. See details in [RFC #13414](gh-issue:13414). +Logprobs are not cached. For a request requiring prompt logprobs, the engine will ignore the prefix cache and recompute the prefill of the full prompt to generate the logprobs. 
#### Deprecated Features diff --git a/examples/offline_inference/basic/README.md b/examples/offline_inference/basic/README.md index 0a2bd6e2b70b3..cbb3116e97414 100644 --- a/examples/offline_inference/basic/README.md +++ b/examples/offline_inference/basic/README.md @@ -52,20 +52,6 @@ Try it yourself with the following argument: ### Quantization -#### AQLM - -vLLM supports models that are quantized using AQLM. - -Try one yourself by passing one of the following models to the `--model` argument: - -- `ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf` -- `ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf` -- `ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf` -- `ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf` -- `BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf` - -> Some of these models are likely to be too large for a single GPU. You can split them across multiple GPUs by setting `--tensor-parallel-size` to the number of required GPUs. - #### GGUF vLLM supports models that are quantized using GGUF. diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index dbf8ed58cc477..dd7559451c4c6 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -70,12 +70,27 @@ def parse_args(): default=64, help=("Maximum number of sequences to be processed in a single iteration."), ) + parser.add_argument( + "--max-model-len", + type=int, + help=("Maximum number of tokens to be processed in a single iteration."), + ) + parser.add_argument( + "--timeout", + type=int, + default=300, + help=("Number of seconds before unresponsive process is killed."), + ) parser.add_argument( "--gpu-memory-utilization", type=float, default=0.8, help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."), ) + parser.add_argument( + "--quantization", + type=str, + ) return parser.parse_args() @@ -90,7 +105,9 @@ def main( enforce_eager, trust_remote_code, max_num_seqs, + max_model_len, gpu_memory_utilization, + 
quantization, ): os.environ["VLLM_DP_RANK"] = str(global_dp_rank) os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) @@ -142,7 +159,9 @@ def main( enable_expert_parallel=True, trust_remote_code=trust_remote_code, max_num_seqs=max_num_seqs, + max_model_len=max_model_len, gpu_memory_utilization=gpu_memory_utilization, + quantization=quantization, ) outputs = llm.generate(prompts, sampling_params) # Print the outputs. @@ -198,14 +217,16 @@ if __name__ == "__main__": args.enforce_eager, args.trust_remote_code, args.max_num_seqs, + args.max_model_len, args.gpu_memory_utilization, + args.quantization, ), ) proc.start() procs.append(proc) exit_code = 0 for proc in procs: - proc.join(timeout=300) + proc.join(timeout=args.timeout) if proc.exitcode is None: print(f"Killing process {proc.pid} that didn't stop within 5 minutes.") proc.kill() diff --git a/examples/offline_inference/dolphin.py b/examples/offline_inference/dolphin.py new file mode 100644 index 0000000000000..d2ba27cd1e027 --- /dev/null +++ b/examples/offline_inference/dolphin.py @@ -0,0 +1,311 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import copy +import os +from dataclasses import dataclass + +import cv2 +import numpy as np +import regex as re +from PIL import Image +from transformers import DonutProcessor + +from vllm import LLM, SamplingParams +from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt +from vllm.multimodal.utils import fetch_image + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +@dataclass +class ImageDimensions: + original_w: int + original_h: int + padded_w: int + padded_h: int + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def map_to_original_coordinates( + x1, y1, x2, y2, dims: ImageDimensions +) -> tuple[int, int, int, int]: + try: + top = (dims.padded_h - dims.original_h) // 2 + left = (dims.padded_w - dims.original_w) // 2 + 
orig_x1 = max(0, x1 - left) + orig_y1 = max(0, y1 - top) + orig_x2 = min(dims.original_w, x2 - left) + orig_y2 = min(dims.original_h, y2 - top) + if orig_x2 <= orig_x1: + orig_x2 = min(orig_x1 + 1, dims.original_w) + if orig_y2 <= orig_y1: + orig_y2 = min(orig_y1 + 1, dims.original_h) + return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2) + except Exception as e: + print(f"map_to_original_coordinates error: {str(e)}") + return 0, 0, min(100, dims.original_w), min(100, dims.original_h) + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def adjust_box_edges(image, boxes: list[list[float]], max_pixels=15, threshold=0.2): + if isinstance(image, str): + image = cv2.imread(image) + img_h, img_w = image.shape[:2] + new_boxes = [] + for box in boxes: + best_box = copy.deepcopy(box) + + def check_edge(img, current_box, i, is_vertical): + edge = current_box[i] + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + _, binary = cv2.threshold( + gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU + ) + if is_vertical: + line = binary[current_box[1] : current_box[3] + 1, edge] + else: + line = binary[edge, current_box[0] : current_box[2] + 1] + transitions = np.abs(np.diff(line)) + return np.sum(transitions) / len(transitions) + + edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)] + current_box = copy.deepcopy(box) + current_box[0] = min(max(current_box[0], 0), img_w - 1) + current_box[1] = min(max(current_box[1], 0), img_h - 1) + current_box[2] = min(max(current_box[2], 0), img_w - 1) + current_box[3] = min(max(current_box[3], 0), img_h - 1) + + for i, direction, is_vertical in edges: + best_score = check_edge(image, current_box, i, is_vertical) + if best_score <= threshold: + continue + for step in range(max_pixels): + current_box[i] += direction + if i == 0 or i == 2: + current_box[i] = min(max(current_box[i], 0), img_w - 1) + else: + current_box[i] = min(max(current_box[i], 0), img_h - 1) + score = check_edge(image, current_box, i, 
is_vertical) + if score < best_score: + best_score = score + best_box = copy.deepcopy(current_box) + if score <= threshold: + break + new_boxes.append(best_box) + return new_boxes + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None): + try: + x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h) + x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h) + x1, y1, x2, y2 = ( + max(0, min(x1, dims.padded_w - 1)), + max(0, min(y1, dims.padded_h - 1)), + max(0, min(x2, dims.padded_w)), + max(0, min(y2, dims.padded_h)), + ) + if x2 <= x1: + x2 = min(x1 + 1, dims.padded_w) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]]) + x1, y1, x2, y2 = new_boxes[0] + x1, y1, x2, y2 = ( + max(0, min(x1, dims.padded_w - 1)), + max(0, min(y1, dims.padded_h - 1)), + max(0, min(x2, dims.padded_w)), + max(0, min(y2, dims.padded_h)), + ) + if x2 <= x1: + x2 = min(x1 + 1, dims.padded_w) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + if previous_box is not None: + prev_x1, prev_y1, prev_x2, prev_y2 = previous_box + if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1): + y1 = prev_y2 + y1 = min(y1, dims.padded_h - 1) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + new_previous_box = [x1, y1, x2, y2] + orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates( + x1, y1, x2, y2, dims + ) + return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box + except Exception as e: + print(f"process_coordinates error: {str(e)}") + orig_x1, orig_y1, orig_x2, orig_y2 = ( + 0, + 0, + min(100, dims.original_w), + min(100, dims.original_h), + ) + return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100] + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def prepare_image(image) -> tuple[np.ndarray, ImageDimensions]: + try: 
+ image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + original_h, original_w = image_cv.shape[:2] + max_size = max(original_h, original_w) + top = (max_size - original_h) // 2 + bottom = max_size - original_h - top + left = (max_size - original_w) // 2 + right = max_size - original_w - left + padded_image = cv2.copyMakeBorder( + image_cv, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0) + ) + padded_h, padded_w = padded_image.shape[:2] + dimensions = ImageDimensions( + original_w=original_w, + original_h=original_h, + padded_w=padded_w, + padded_h=padded_h, + ) + return padded_image, dimensions + except Exception as e: + print(f"prepare_image error: {str(e)}") + h, w = image.height, image.width + dimensions = ImageDimensions(original_w=w, original_h=h, padded_w=w, padded_h=h) + return np.zeros((h, w, 3), dtype=np.uint8), dimensions + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def parse_layout_string(bbox_str): + """Parse layout string using regular expressions""" + pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)" + matches = re.finditer(pattern, bbox_str) + + parsed_results = [] + for match in matches: + coords = [float(match.group(i)) for i in range(1, 5)] + label = match.group(5).strip() + parsed_results.append((coords, label)) + + return parsed_results + + +model_id = "ByteDance/Dolphin" + +# The input image size for Dolphin is 896 x 896, +# and the patch_size is 4 x 4. +# Therefore, the initial number of patches is: +# Height: 896 / 4 = 224 patches +# Width: 896 / 4 = 224 patches + +# The Dolphin model uses a staged downsampling approach, +# defined by the "depths": [2, 2, 14, 2] configuration. +# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, +# which halves the feature map's dimensions (dividing both height and width by 2). +# Before Stage 2: The size changes from 224 x 224 to (224/2) x (224/2) = 112 x 112. 
+# Before Stage 3: The size changes from 112 x 112 to (112/2) x (112/2) = 56 x 56. +# Before Stage 4: The size changes from 56 x 56 to (56/2) x (56/2) = 28 x 28. + +# Because vLLM needs to fill the image features with an encoder_prompt, +# and the encoder_prompt will have `<pad>` tokens added when tokenized, +# we need to construct an encoder_prompt with a length of 28 x 28 - 1 = 783. +encoder_prompt = "".join(["0"] * 783) +sampling_params = SamplingParams( + temperature=0.0, + max_tokens=2048, +) + +processor = DonutProcessor.from_pretrained(model_id) +llm = LLM( + model=model_id, + dtype="float16", + max_num_seqs=8, + hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, +) + +parser = argparse.ArgumentParser() +parser.add_argument( + "--image_path", type=str, default=None, help="Path to a local image file." +) +args = parser.parse_args() + +if args.image_path: + if not os.path.exists(args.image_path): + raise FileNotFoundError(f"Error: File not found at {args.image_path}") + image = Image.open(args.image_path).convert("RGB") +else: + image = fetch_image( + "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" + ) + + +prompt = "Parse the reading order of this document. 
" +decoder_prompt = f"<s>{prompt}<Answer/>" +decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[ + "input_ids" + ] +) +enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}), + decoder_prompt=decoder_prompt_tokens, +) +layout_outputs = llm.generate(prompts=enc_dec_prompt, sampling_params=sampling_params) +layout_result_str = layout_outputs[0].outputs[0].text +print(f"Layout analysis output:\n{layout_result_str}") + +padded_image, dims = prepare_image(image) +layout_results = parse_layout_string(layout_result_str) +text_table_elements = [] +previous_box = None +reading_order = 0 +for bbox_coords, label in layout_results: + if label == "fig": + continue + try: + x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = ( + process_coordinates(bbox_coords, padded_image, dims, previous_box) + ) + cropped = padded_image[y1:y2, x1:x2] + if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3: + pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) + prompt_ocr = ( + "Parse the table in the image. " + if label == "tab" + else "Read text in the image. 
" + ) + text_table_elements.append( + { + "crop": pil_crop, + "prompt": prompt_ocr, + "reading_order": reading_order, + } + ) + reading_order += 1 + except Exception as e: + print(f"Error processing bbox (label: {label}): {str(e)}") + continue + +if text_table_elements: + batch_prompts = [] + for elem in text_table_elements: + decoder_prompt_str = f"<s>{elem['prompt']}<Answer/>" + decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer( + decoder_prompt_str, add_special_tokens=False + )["input_ids"] + ) + enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt( + prompt=encoder_prompt, multi_modal_data={"image": elem["crop"]} + ), + decoder_prompt=decoder_prompt_tokens, + ) + batch_prompts.append(enc_dec_prompt) + batch_outputs = llm.generate(prompts=batch_prompts, sampling_params=sampling_params) + for i, output in enumerate(batch_outputs): + text_table_elements[i]["text"] = output.outputs[0].text.strip() + +print("------" * 8) +text_table_elements.sort(key=lambda x: x["reading_order"]) +for elem in text_table_elements: + print(elem.get("text", "")) diff --git a/examples/offline_inference/encoder_decoder.py b/examples/offline_inference/encoder_decoder.py index 0da6fa5c4af5f..df6c1eaf4a21e 100644 --- a/examples/offline_inference/encoder_decoder.py +++ b/examples/offline_inference/encoder_decoder.py @@ -2,9 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Demonstrate prompting of text-to-text -encoder/decoder models, specifically BART +encoder/decoder models, specifically BART and mBART. + +This script is refactored to allow model selection via command-line arguments. 
""" +import argparse +from typing import NamedTuple, Optional + from vllm import LLM, SamplingParams from vllm.inputs import ( ExplicitEncoderDecoderPrompt, @@ -14,119 +19,175 @@ from vllm.inputs import ( ) -def create_prompts(tokenizer): - # Test prompts - # - # This section shows all of the valid ways to prompt an - # encoder/decoder model. - # - # - Helpers for building prompts - text_prompt_raw = "Hello, my name is" - text_prompt = TextPrompt(prompt="The president of the United States is") +class ModelRequestData(NamedTuple): + """ + Holds the configuration for a specific model, including its + HuggingFace ID and the prompts to use for the demo. + """ + + model_id: str + encoder_prompts: list + decoder_prompts: list + hf_overrides: Optional[dict] = None + + +def get_bart_config() -> ModelRequestData: + """ + Returns the configuration for facebook/bart-large-cnn. + This uses the exact test cases from the original script. + """ + encoder_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "An encoder prompt", + ] + decoder_prompts = [ + "A decoder prompt", + "Another decoder prompt", + ] + return ModelRequestData( + model_id="facebook/bart-large-cnn", + encoder_prompts=encoder_prompts, + decoder_prompts=decoder_prompts, + ) + + +def get_mbart_config() -> ModelRequestData: + """ + Returns the configuration for facebook/mbart-large-en-ro. + This uses prompts suitable for an English-to-Romanian translation task. 
+ """ + encoder_prompts = [ + "The quick brown fox jumps over the lazy dog.", + "How are you today?", + ] + decoder_prompts = ["", ""] + hf_overrides = {"architectures": ["MBartForConditionalGeneration"]} + return ModelRequestData( + model_id="facebook/mbart-large-en-ro", + encoder_prompts=encoder_prompts, + decoder_prompts=decoder_prompts, + hf_overrides=hf_overrides, + ) + + +MODEL_GETTERS = { + "bart": get_bart_config, + "mbart": get_mbart_config, +} + + +def create_all_prompt_types( + encoder_prompts_raw: list, + decoder_prompts_raw: list, + tokenizer, +) -> list: + """ + Generates a list of diverse prompt types for demonstration. + This function is generic and uses the provided raw prompts + to create various vLLM input objects. + """ + text_prompt_raw = encoder_prompts_raw[0] + text_prompt = TextPrompt(prompt=encoder_prompts_raw[1 % len(encoder_prompts_raw)]) tokens_prompt = TokensPrompt( - prompt_token_ids=tokenizer.encode(prompt="The capital of France is") - ) - # - Pass a single prompt to encoder/decoder model - # (implicitly encoder input prompt); - # decoder input prompt is assumed to be None - - single_text_prompt_raw = text_prompt_raw # Pass a string directly - single_text_prompt = text_prompt # Pass a TextPrompt - single_tokens_prompt = tokens_prompt # Pass a TokensPrompt - - # ruff: noqa: E501 - # - Pass explicit encoder and decoder input prompts within one data structure. - # Encoder and decoder prompts can both independently be text or tokens, with - # no requirement that they be the same prompt type. Some example prompt-type - # combinations are shown below, note that these are not exhaustive. 
- - enc_dec_prompt1 = ExplicitEncoderDecoderPrompt( - # Pass encoder prompt string directly, & - # pass decoder prompt tokens - encoder_prompt=single_text_prompt_raw, - decoder_prompt=single_tokens_prompt, - ) - enc_dec_prompt2 = ExplicitEncoderDecoderPrompt( - # Pass TextPrompt to encoder, and - # pass decoder prompt string directly - encoder_prompt=single_text_prompt, - decoder_prompt=single_text_prompt_raw, - ) - enc_dec_prompt3 = ExplicitEncoderDecoderPrompt( - # Pass encoder prompt tokens directly, and - # pass TextPrompt to decoder - encoder_prompt=single_tokens_prompt, - decoder_prompt=single_text_prompt, + prompt_token_ids=tokenizer.encode( + encoder_prompts_raw[2 % len(encoder_prompts_raw)] + ) ) - # - Finally, here's a useful helper function for zipping encoder and - # decoder prompts together into a list of ExplicitEncoderDecoderPrompt - # instances + decoder_tokens_prompt = TokensPrompt( + prompt_token_ids=tokenizer.encode(decoder_prompts_raw[0]) + ) + single_prompt_examples = [ + text_prompt_raw, + text_prompt, + tokens_prompt, + ] + explicit_pair_examples = [ + ExplicitEncoderDecoderPrompt( + encoder_prompt=text_prompt_raw, + decoder_prompt=decoder_tokens_prompt, + ), + ExplicitEncoderDecoderPrompt( + encoder_prompt=text_prompt, + decoder_prompt=decoder_prompts_raw[1 % len(decoder_prompts_raw)], + ), + ExplicitEncoderDecoderPrompt( + encoder_prompt=tokens_prompt, + decoder_prompt=text_prompt, + ), + ] zipped_prompt_list = zip_enc_dec_prompts( - ["An encoder prompt", "Another encoder prompt"], - ["A decoder prompt", "Another decoder prompt"], + encoder_prompts_raw, + decoder_prompts_raw, ) - - # - Let's put all of the above example prompts together into one list - # which we will pass to the encoder/decoder LLM. 
- return [ - single_text_prompt_raw, - single_text_prompt, - single_tokens_prompt, - enc_dec_prompt1, - enc_dec_prompt2, - enc_dec_prompt3, - ] + zipped_prompt_list + return single_prompt_examples + explicit_pair_examples + zipped_prompt_list -# Create a sampling params object. -def create_sampling_params(): +def create_sampling_params() -> SamplingParams: + """Create a sampling params object.""" return SamplingParams( temperature=0, top_p=1.0, min_tokens=0, - max_tokens=20, + max_tokens=30, ) -# Print the outputs. -def print_outputs(outputs): - print("-" * 50) +def print_outputs(outputs: list): + """Formats and prints the generation outputs.""" + print("-" * 80) for i, output in enumerate(outputs): prompt = output.prompt encoder_prompt = output.encoder_prompt generated_text = output.outputs[0].text print(f"Output {i + 1}:") - print( - f"Encoder prompt: {encoder_prompt!r}\n" - f"Decoder prompt: {prompt!r}\n" - f"Generated text: {generated_text!r}" + print(f"Encoder Prompt: {encoder_prompt!r}") + print(f"Decoder Prompt: {prompt!r}") + print(f"Generated Text: {generated_text!r}") + print("-" * 80) + + +def main(args): + """Main execution function.""" + model_key = args.model + if model_key not in MODEL_GETTERS: + raise ValueError( + f"Unknown model: {model_key}. 
" + f"Available models: {list(MODEL_GETTERS.keys())}" ) - print("-" * 50) + config_getter = MODEL_GETTERS[model_key] + model_config = config_getter() - -def main(): - dtype = "float" - - # Create a BART encoder/decoder model instance + print(f"🚀 Running demo for model: {model_config.model_id}") llm = LLM( - model="facebook/bart-large-cnn", - dtype=dtype, + model=model_config.model_id, + dtype="float", + hf_overrides=model_config.hf_overrides, ) - - # Get BART tokenizer tokenizer = llm.llm_engine.get_tokenizer_group() - - prompts = create_prompts(tokenizer) + prompts = create_all_prompt_types( + encoder_prompts_raw=model_config.encoder_prompts, + decoder_prompts_raw=model_config.decoder_prompts, + tokenizer=tokenizer, + ) sampling_params = create_sampling_params() - - # Generate output tokens from the prompts. The output is a list of - # RequestOutput objects that contain the prompt, generated - # text, and other information. outputs = llm.generate(prompts, sampling_params) - print_outputs(outputs) if __name__ == "__main__": - main() + parser = argparse.ArgumentParser( + description="A flexible demo for vLLM encoder-decoder models." 
+ ) + parser.add_argument( + "--model", + "-m", + type=str, + default="bart", + choices=MODEL_GETTERS.keys(), + help="The short name of the model to run.", + ) + args = parser.parse_args() + main(args) diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index d27a902edb7e7..655f9f3fce7ae 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -13,6 +13,7 @@ from typing import NamedTuple from vllm import LLM, EngineArgs, PromptType, SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset +from vllm.multimodal.utils import fetch_image from vllm.utils import FlexibleArgumentParser @@ -21,6 +22,50 @@ class ModelRequestData(NamedTuple): prompts: Sequence[PromptType] +def run_donut(): + engine_args = EngineArgs( + model="naver-clova-ix/donut-base-finetuned-docvqa", + max_num_seqs=2, + limit_mm_per_prompt={"image": 1}, + dtype="float16", + hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, + ) + + # The input image size for donut-base-finetuned-docvqa is 2560 x 1920, + # and the patch_size is 4 x 4. + # Therefore, the initial number of patches is: + # Height: 1920 / 4 = 480 patches + # Width: 2560 / 4 = 640 patches + # The Swin model uses a staged downsampling approach, + # defined by the "depths": [2, 2, 14, 2] configuration. + # Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, + # which halves the feature map's dimensions (dividing both height and width by 2). + # Before Stage 2: The size changes from 480 x 640 to (480/2) x (640/2) = 240 x 320. + # Before Stage 3: The size changes from 240 x 320 to (240/2) x (320/2) = 120 x 160. + # Before Stage 4: The size changes from 120 x 160 to (120/2) x (160/2) = 60 x 80. 
+ # Because vLLM needs to fill the image features with an encoder_prompt, + # and the encoder_prompt will have `<pad>` tokens added when tokenized, + # we need to construct an encoder_prompt with a length of 60 x 80 - 1 = 4799. + prompts = [ + { + "encoder_prompt": { + "prompt": "".join(["$"] * 4799), + "multi_modal_data": { + "image": fetch_image( + "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" + ) # noqa: E501 + }, + }, + "decoder_prompt": "<s_docvqa><s_question>What time is the coffee break?</s_question><s_answer>", # noqa: E501 + }, + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + def run_florence2(): engine_args = EngineArgs( model="microsoft/Florence-2-large", @@ -118,6 +163,7 @@ def run_whisper(): model_example_map = { + "donut": run_donut, "florence2": run_florence2, "mllama": run_mllama, "whisper": run_whisper, diff --git a/examples/offline_inference/logits_processor.py b/examples/offline_inference/logits_processor.py new file mode 100644 index 0000000000000..7ef20efa7d28c --- /dev/null +++ b/examples/offline_inference/logits_processor.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""This example demonstrates instantiating vLLM with a custom logits processor +class object. + +For a basic example of implementing a custom logits processor, see +the `DummyLogitsProcessor` implementation in `vllm/test_utils.py`. + +For testing purposes, a dummy logits processor is employed which, if +`target_token` is passed as a keyword argument to `SamplingParams.extra_args`, +will mask out all tokens except `target_token`. 
+ +A batch is constructed with `temperature=0.0` and 50% of requests specifying +`target_token`, and for these requests - and *only* these requests - we +expect the `target_token` to be decoded in each step, yielding an output +similar to that shown below: + +Generated Outputs: +------------------------------------------------------------ +Prompt: 'Hello, my name is' +Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '" +------------------------------------------------------------ +Prompt: 'The president of the United States is' +Output: " not a racist. He is a racist.\nHe's a racist because he" +------------------------------------------------------------ +Prompt: 'The capital of France is' +Output: ' also also also also also also also also also also also also also + also also also' +------------------------------------------------------------ +Prompt: 'The future of AI is' +Output: ' in the hands of the people.\n\nThe future of AI is in the' +------------------------------------------------------------ +""" + +from typing import Optional + +import torch + +from vllm import LLM, SamplingParams +from vllm.config import VllmConfig +from vllm.v1.sample.logits_processor import ( + BatchUpdate, + LogitsProcessor, + MoveDirectionality, +) + + +# Hypothetical custom logits processor +class DummyLogitsProcessor(LogitsProcessor): + """Fake logit processor to support unit testing and examples""" + + def __init__( + self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool + ): + self.req_info: dict[int, SamplingParams] = {} + + def is_argmax_invariant(self) -> bool: + """Never impacts greedy sampling""" + return False + + def update_state(self, batch_update: Optional[BatchUpdate]): + if not batch_update: + return + + # Process added requests. 
+ for index, params, _, _ in batch_update.added: + assert params is not None + if params.extra_args and ( + target_token := params.extra_args.get("target_token") + ): + self.req_info[index] = target_token + + if self.req_info: + # Process removed requests. + for index in batch_update.removed: + self.req_info.pop(index, None) + + # Process moved requests, unidirectional move (a->b) and swap + # (a<->b) + for adx, bdx, direct in batch_update.moved: + a_val = self.req_info.pop(adx, None) + b_val = self.req_info.pop(bdx, None) + if a_val is not None: + self.req_info[bdx] = a_val + if direct == MoveDirectionality.SWAP and b_val is not None: + self.req_info[adx] = b_val + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if not self.req_info: + return logits + + # Save target values before modification + rows_list = list(self.req_info.keys()) + cols = torch.tensor( + [self.req_info[i] for i in rows_list], + dtype=torch.long, + device=logits.device, + ) + rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device) + values_to_keep = logits[rows, cols].clone() + + # Mask all but target tokens + logits[rows] = float("-inf") + logits[rows, cols] = values_to_keep + + return logits + + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a mixture of requests which do and don't utilize the dummy logitproc +sampling_params_list = [ + SamplingParams(temperature=0.0, extra_args={"target_token": 128}), + SamplingParams(temperature=0.0), + SamplingParams(temperature=0.0, extra_args={"target_token": 67}), + SamplingParams(temperature=0.0), +] + + +def main(): + # Create an LLM. + llm = LLM( + model="facebook/opt-125m", + logits_processors=[DummyLogitsProcessor], + ) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. 
+ outputs = llm.generate(prompts, sampling_params_list) + # Print the outputs. + print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 184c30891eca7..c4972f02d0f8e 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -5,6 +5,7 @@ from transformers import AutoTokenizer from vllm import LLM, SamplingParams from vllm.benchmarks.datasets import add_dataset_parser, get_samples +from vllm.inputs import TokensPrompt from vllm.v1.metrics.reader import Counter, Vector try: @@ -137,7 +138,8 @@ def main(): sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) if not args.custom_mm_prompts: outputs = llm.generate( - prompt_token_ids=prompt_ids, sampling_params=sampling_params + TokensPrompt(prompt_token_ids=prompt_ids), + sampling_params=sampling_params, ) else: outputs = llm.chat(prompts, sampling_params=sampling_params) diff --git a/examples/offline_inference/structured_outputs.py b/examples/offline_inference/structured_outputs.py index 8ef121ebe848e..88d87beb4874d 100644 --- a/examples/offline_inference/structured_outputs.py +++ b/examples/offline_inference/structured_outputs.py @@ -15,6 +15,8 @@ from pydantic import BaseModel from vllm import LLM, SamplingParams from vllm.sampling_params import GuidedDecodingParams +MAX_TOKENS = 50 + # Guided decoding by Choice (list of possible options) guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"]) sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice) @@ -23,7 +25,9 @@ prompt_choice = "Classify this sentiment: vLLM is wonderful!" 
# Guided decoding by Regex guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n") sampling_params_regex = SamplingParams( - guided_decoding=guided_decoding_params_regex, stop=["\n"] + guided_decoding=guided_decoding_params_regex, + stop=["\n"], + max_tokens=MAX_TOKENS, ) prompt_regex = ( "Generate an email address for Alan Turing, who works in Enigma." @@ -48,7 +52,10 @@ class CarDescription(BaseModel): json_schema = CarDescription.model_json_schema() guided_decoding_params_json = GuidedDecodingParams(json=json_schema) -sampling_params_json = SamplingParams(guided_decoding=guided_decoding_params_json) +sampling_params_json = SamplingParams( + guided_decoding=guided_decoding_params_json, + max_tokens=MAX_TOKENS, +) prompt_json = ( "Generate a JSON with the brand, model and car_type of" "the most iconic car from the 90's" @@ -64,7 +71,10 @@ condition ::= column "= " number number ::= "1 " | "2 " """ guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar) -sampling_params_grammar = SamplingParams(guided_decoding=guided_decoding_params_grammar) +sampling_params_grammar = SamplingParams( + guided_decoding=guided_decoding_params_grammar, + max_tokens=MAX_TOKENS, +) prompt_grammar = ( "Generate an SQL query to show the 'username' and 'email'from the 'users' table." 
) @@ -75,7 +85,7 @@ def format_output(title: str, output: str): def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM): - outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) + outputs = llm.generate(prompt, sampling_params=sampling_params) return outputs[0].outputs[0].text diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 988ad35cdd7e6..4e879666f61d7 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -173,6 +173,37 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: ) +# Ernie4.5-VL +def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData: + model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, + ) + + if modality == "image": + placeholder = "Picture 1:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" + elif modality == "video": + placeholder = "Video 1:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" + + prompts = [ + ( + f"<|begin_of_sentence|>User: {question}{placeholder}\n" + "Assistant: <think></think>" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # Florence2 def run_florence2(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -283,8 +314,10 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData: ) prompts = [ - f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\ - {question}<|assistant|>" + ( + "<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>" + f"{question}<|assistant|>" + ) for question in questions ] @@ -333,6 +366,80 @@ def run_glm4_1v(questions: list[str], modality: str) -> ModelRequestData: ) +# GLM-4.5V +def run_glm4_5v(questions: list[str], 
modality: str) -> ModelRequestData: + model_name = "zai-org/GLM-4.5V" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + mm_processor_kwargs={ + "size": {"shortest_edge": 12544, "longest_edge": 47040000}, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + enforce_eager=True, + tensor_parallel_size=4, + ) + + if modality == "image": + placeholder = "<|begin_of_image|><|image|><|end_of_image|>" + elif modality == "video": + placeholder = "<|begin_of_video|><|video|><|end_of_video|>" + + prompts = [ + ( + "[gMASK]<sop><|system|>\nYou are a helpful assistant.<|user|>\n" + f"{placeholder}" + f"{question}<|assistant|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + +# GLM-4.5V-FP8 +def run_glm4_5v_fp8(questions: list[str], modality: str) -> ModelRequestData: + model_name = "zai-org/GLM-4.5V-FP8" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + mm_processor_kwargs={ + "size": {"shortest_edge": 12544, "longest_edge": 47040000}, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + enforce_eager=True, + tensor_parallel_size=4, + ) + + if modality == "image": + placeholder = "<|begin_of_image|><|image|><|end_of_image|>" + elif modality == "video": + placeholder = "<|begin_of_video|><|video|><|end_of_video|>" + + prompts = [ + ( + "[gMASK]<sop><|system|>\nYou are a helpful assistant.<|user|>\n" + f"{placeholder}" + f"{question}<|assistant|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # H2OVL-Mississippi def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -383,8 +490,8 @@ def run_hyperclovax_seed_vision( for question in questions: if modality == "image": """ - ocr: List the words in the image in raster order. 
- Even if the word order feels unnatural for reading, + ocr: List the words in the image in raster order. + Even if the word order feels unnatural for reading, the model will handle it as long as it follows raster order. e.g. "Naver, CLOVA, bigshane" lens_keywords: List the entity names in the image. @@ -693,15 +800,13 @@ def run_llava_next_video(questions: list[str], modality: str) -> ModelRequestDat def run_llava_onevision(questions: list[str], modality: str) -> ModelRequestData: if modality == "video": prompts = [ - f"<|im_start|>user <video>\n{question}<|im_end|> \ - <|im_start|>assistant\n" + f"<|im_start|>user <video>\n{question}<|im_end|><|im_start|>assistant\n" for question in questions ] elif modality == "image": prompts = [ - f"<|im_start|>user <image>\n{question}<|im_end|> \ - <|im_start|>assistant\n" + f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n" for question in questions ] @@ -815,6 +920,39 @@ def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData: return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-V-2_6") +def run_minimax_vl_01(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + model_name = "MiniMaxAI/MiniMax-VL-01" + + engine_args = EngineArgs( + model=model_name, + max_num_seqs=2, + limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, + tensor_parallel_size=8, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + messages = [ + [ + { + "role": "user", + "content": [{"type": "image"}, {"type": "text", "text": question}], + } + ] + for question in questions + ] + prompts = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # Mistral-3 HF-format def run_mistral3(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -891,8 +1029,7 @@ def run_molmo(questions: list[str], modality: str) 
-> ModelRequestData: ) prompts = [ - f"<|im_start|>user <image>\n{question}<|im_end|> \ - <|im_start|>assistant\n" + f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n" for question in questions ] @@ -998,6 +1135,38 @@ def run_ovis(questions: list[str], modality: str) -> ModelRequestData: ) +# Ovis2_5 +def run_ovis2_5(questions: list[str], modality: str) -> ModelRequestData: + model_name = "AIDC-AI/Ovis2.5-2B" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + trust_remote_code=True, + dtype="half", + limit_mm_per_prompt={modality: 1}, + ) + if modality == "image": + placeholder = "<image>" + elif modality == "video": + placeholder = "<video>" + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + messages = [ + [{"role": "user", "content": f"{placeholder}\n{question}"}] + for question in questions + ] + prompts = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # PaliGemma def run_paligemma(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1297,6 +1466,28 @@ def run_qwen2_5_omni(questions: list[str], modality: str): ) +# R-4B +def run_r_vl(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + model_name = "YannQi/R-4B" + + prompts = [ + f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n" + for question in questions + ] + + engine_args = EngineArgs( + model=model_name, + max_model_len=16384, + limit_mm_per_prompt={modality: 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # SkyworkR1V def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1442,12 +1633,15 @@ model_example_map = { "chameleon": run_chameleon, "command_a_vision": run_command_a_vision, "deepseek_vl_v2": 
run_deepseek_vl2, + "ernie45_vl": run_ernie45_vl, "florence2": run_florence2, "fuyu": run_fuyu, "gemma3": run_gemma3, "gemma3n": run_gemma3n, "glm4v": run_glm4v, "glm4_1v": run_glm4_1v, + "glm4_5v": run_glm4_5v, + "glm4_5v_fp8": run_glm4_5v_fp8, "h2ovl_chat": run_h2ovl, "hyperclovax_seed_vision": run_hyperclovax_seed_vision, "idefics3": run_idefics3, @@ -1463,12 +1657,14 @@ model_example_map = { "mantis": run_mantis, "minicpmo": run_minicpmo, "minicpmv": run_minicpmv, + "minimax_vl_01": run_minimax_vl_01, "mistral3": run_mistral3, "mllama": run_mllama, "molmo": run_molmo, "nemotron_vl": run_nemotron_vl, "NVLM_D": run_nvlm_d, "ovis": run_ovis, + "ovis2_5": run_ovis2_5, "paligemma": run_paligemma, "paligemma2": run_paligemma2, "phi3_v": run_phi3v, @@ -1479,6 +1675,7 @@ model_example_map = { "qwen2_vl": run_qwen2_vl, "qwen2_5_vl": run_qwen2_5_vl, "qwen2_5_omni": run_qwen2_5_omni, + "rvl": run_r_vl, "skywork_chat": run_skyworkr1v, "smolvlm": run_smolvlm, "step3": run_step3, diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 799337ed68503..d9242efa85470 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -680,6 +680,36 @@ def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData: ) +# ovis2_5 +def load_ovis2_5(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "AIDC-AI/Ovis2.5-2B" + + engine_args = EngineArgs( + model=model_name, + max_model_len=8192, + max_num_seqs=2, + trust_remote_code=True, + dtype="half", + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = "\n".join( + f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1) + ) + messages = [{"role": "user", "content": f"{placeholders}\n{question}"}] + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + prompt = tokenizer.apply_chat_template( + 
messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "mistral-community/pixtral-12b" @@ -962,6 +992,39 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_r_vl(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "YannQi/R-4B" + engine_args = EngineArgs( + model=model_name, + max_model_len=16384, + max_num_seqs=16, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "HuggingFaceTB/SmolVLM2-2.2B-Instruct" @@ -1064,6 +1127,76 @@ def load_tarsier2(question: str, image_urls: list[str]) -> ModelRequestData: ) +# GLM-4.5V +def load_glm4_5v(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "zai-org/GLM-4.5V" + + engine_args = EngineArgs( + model=model_name, + max_model_len=32768, + max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}, + enforce_eager=True, + tensor_parallel_size=4, + ) + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + processor = AutoProcessor.from_pretrained(model_name) + prompt = 
processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + image_data = [fetch_image(url) for url in image_urls] + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=image_data, + ) + + +# GLM-4.5V-FP8 +def load_glm4_5v_fp8(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "zai-org/GLM-4.5V-FP8" + + engine_args = EngineArgs( + model=model_name, + max_model_len=32768, + max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}, + enforce_eager=True, + tensor_parallel_size=4, + ) + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + processor = AutoProcessor.from_pretrained(model_name) + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + image_data = [fetch_image(url) for url in image_urls] + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=image_data, + ) + + model_example_map = { "aria": load_aria, "aya_vision": load_aya_vision, @@ -1085,6 +1218,7 @@ model_example_map = { "mllama": load_mllama, "NVLM_D": load_nvlm_d, "ovis": load_ovis, + "ovis2_5": load_ovis2_5, "phi3_v": load_phi3v, "phi4_mm": load_phi4mm, "phi4_multimodal": load_phi4_multimodal, @@ -1092,10 +1226,13 @@ model_example_map = { "qwen_vl_chat": load_qwen_vl_chat, "qwen2_vl": load_qwen2_vl, "qwen2_5_vl": load_qwen2_5_vl, + "rvl": load_r_vl, "smolvlm": load_smolvlm, "step3": load_step3, "tarsier": load_tarsier, "tarsier2": load_tarsier2, + "glm4_5v": load_glm4_5v, + "glm4_5v_fp8": load_glm4_5v_fp8, } diff --git a/examples/tool_chat_template_deepseekv31.jinja b/examples/tool_chat_template_deepseekv31.jinja new file mode 100644 index 0000000000000..863be69d60b68 --- /dev/null +++ b/examples/tool_chat_template_deepseekv31.jinja @@ -0,0 +1,91 @@ +{% if not add_generation_prompt is defined 
%} + {% set add_generation_prompt = false %} +{% endif %} +{% if not thinking is defined %} + {% set thinking = false %} +{% endif %} +{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) %} +{%- for message in messages %} + {%- if message['role'] == 'system' %} + {%- if ns.is_first_sp %} + {% set ns.system_prompt = ns.system_prompt + message['content'] %} + {% set ns.is_first_sp = false %} + {%- else %} + {% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %} + {%- endif %} + {%- endif %} +{%- endfor %} + +{% if tools is defined and tools is not none %} + {% set tool_ns = namespace(text='## Tools\nYou have access to the following tools:\n') %} + {% for tool in tools %} + {% set tool_ns.text = tool_ns.text + '\n### ' + tool.function.name + '\nDescription: ' + tool.function.description + '\n\nParameters: ' + (tool.function.parameters | tojson) + '\n' %} + {% endfor %} + {% set tool_ns.text = tool_ns.text + "\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\n<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|>\n\nWhere:\n\n- `tool_call_name` must be an exact match to one of the available tools\n- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\n- For multiple tool calls, chain them directly without separators or spaces\n" %} + {% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %} +{% endif %} + +{{ bos_token }}{{ ns.system_prompt }} +{%- for message in messages %} + {%- if message['role'] == 'user' %} + {%- set ns.is_tool = false -%} + {%- set ns.is_first = false -%} + {%- set ns.is_last_user = true -%} + {{'<|User|>' + message['content']}} + {%- endif %} + {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %} + {%- if ns.is_last_user %} + {{'<|Assistant|></think>'}} + 
{%- endif %} + {%- set ns.is_last_user = false -%} + {%- set ns.is_first = false %} + {%- set ns.is_tool = false -%} + {%- for tool in message['tool_calls'] %} + {%- if not ns.is_first %} + {%- if message['content'] is none %} + {{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- else %} + {{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- endif %} + {%- set ns.is_first = true -%} + {%- else %} + {{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- endif %} + {%- endfor %} + {{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} + {%- endif %} + {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %} + {%- if ns.is_last_user %} + {{'<|Assistant|>'}} + {%- if message['prefix'] is defined and message['prefix'] and thinking %} + {{'<think>'}} + {%- else %} + {{'</think>'}} + {%- endif %} + {%- endif %} + {%- set ns.is_last_user = false -%} + {%- if ns.is_tool %} + {{message['content'] + '<|end▁of▁sentence|>'}} + {%- set ns.is_tool = false -%} + {%- else %} + {%- set content = message['content'] -%} + {%- if '</think>' in content %} + {%- set content = content.split('</think>', 1)[1] -%} + {%- endif %} + {{content + '<|end▁of▁sentence|>'}} + {%- endif %} + {%- endif %} + {%- if message['role'] == 'tool' %} + {%- set ns.is_last_user = false -%} + {%- set ns.is_tool = true -%} + {{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- endif %} +{%- endfor -%} +{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool %} + {{'<|Assistant|>'}} + {%- if not thinking %} + {{'</think>'}} + {%- else %} + {{'<think>'}} + {%- endif %} +{% endif %} diff --git 
a/examples/tool_chat_template_gemma3_pythonic.jinja b/examples/tool_chat_template_gemma3_pythonic.jinja new file mode 100644 index 0000000000000..5a20b01911295 --- /dev/null +++ b/examples/tool_chat_template_gemma3_pythonic.jinja @@ -0,0 +1,123 @@ +{#- Begin-of-sequence token to start the model prompt -#} +{{ bos_token }} +{#- Extracts the system message. Gemma does not support system messages so it will be prepended to first user message. -#} +{%- if messages[0]['role'] == 'system' -%} + {%- if messages[0]['content'] is string -%} + {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%} + {%- else -%} + {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%} + {%- endif -%} + {%- set loop_messages = messages[1:] -%} +{%- else -%} + {%- set first_user_prefix = "" -%} + {%- set loop_messages = messages -%} +{%- endif -%} +{#- Set tools to none if not defined for this ChatCompletion request (helps avoid errors later) -#} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} +{#- Validate alternating user/assistant messages (excluding 'tool' messages and ones with tool_calls) -#} +{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | selectattr("tool_calls", "undefined") -%} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif -%} +{%- endfor -%} + +{#- Main loop over all messages in the conversation history -#} +{%- for message in loop_messages -%} + {#- Normalize roles for model prompt formatting -#} + {%- if (message['role'] == 'assistant') -%} + {%- set role = "model" -%} + {%- elif (message['role'] == 'tool') -%} + {%- set role = "user" -%} + {%- else -%} + {%- set role = message['role'] -%} + {%- endif -%} + {#- Mark the start of a message block with the appropriate role -#} + {{ '<start_of_turn>' + role + '\n' -}} + + {#- Insert system message content (if present) at the 
beginning of the first message. -#} + {%- if loop.first -%} + {{ first_user_prefix }} + {#- Append system message with tool information if using tools in message request. -#} + {%- if tools is not none -%} + {{- "Tools (functions) are available. If you decide to invoke one or more of the tools, you must respond with a python list of the function calls.\n" -}} + {{- "Example Format: [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] \n" -}} + {{- "Do not use variables. DO NOT USE MARKDOWN SYNTAX. You SHOULD NOT include any other text in the response if you call a function. If none of the functions can be used, point it out. If you lack the parameters required by the function, also point it out.\n" -}} + {{- "Here is a list of functions in JSON format that you can invoke.\n" -}} + {{- tools | tojson(indent=4) -}} + {{- "\n\n" -}} + {%- endif -%} + {%- endif -%} + + {#- Format model tool calls (turns where model indicates they want to call a tool) -#} + {%- if 'tool_calls' in message -%} + {#- Opening bracket for tool call list. -#} + {{- '[' -}} + {#- For each tool call -#} + {%- for tool_call in message.tool_calls -%} + {#- Get tool call function. -#} + {%- if tool_call.function is defined -%} + {%- set tool_call = tool_call.function -%} + {%- endif -%} + {#- Function name & opening parenthesis. 
-#} + {{- tool_call.name + '(' -}} + + {#-- Handle arguments as list (positional) or dict (named) --#} + {#-- Named arguments (dict) --#} + {%- if tool_call.arguments is iterable and tool_call.arguments is mapping -%} + {%- set first = true -%} + {%- for key, val in tool_call.arguments.items() -%} + {%- if not first %}, {% endif -%} + {{ key }}={{ val | tojson }} + {%- set first = false -%} + {%- endfor -%} + {#-- Positional arguments (list) --#} + {%- elif tool_call.arguments is iterable -%} + {{- tool_call.arguments | map('tojson') | join(', ') -}} + {#-- Fallback: single positional value --#} + {%- else -%} + {{- tool_call.arguments | tojson -}} + {#-- Closing parenthesis. --#} + {%- endif -%} + {{- ')' -}} + {#-- If more than one tool call, place comma and move to formatting next tool call --#} + {%- if not loop.last -%}, {% endif -%} + {%- endfor -%} + {#- Closing bracket for tool call list. -#} + {{- ']' -}} + {%- endif -%} + + {#- Tool response start tag (for messages from a tool) -#} + {%- if (message['role'] == 'tool') -%} + {{ '<tool_response>\n' -}} + {%- endif -%} + + {#- Render the message content: handle plain string or multimodal content like image/text -#} + {%- if message['content'] is string -%} + {{ message['content'] | trim }} + {%- elif message['content'] is iterable -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'image' -%} + {{ '<start_of_image>' }} + {%- elif item['type'] == 'text' -%} + {{ item['text'] | trim }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ raise_exception("Invalid content type") }} + {%- endif -%} + + {#- Tool response end tag -#} + {%- if (message['role'] == 'tool') -%} + {{ '</tool_response>' -}} + {%- endif -%} + + {#- Mark end of a single turn -#} + {{ '<end_of_turn>\n' }} +{%- endfor -%} + +{#- If generation is to be triggered, add model prompt prefix -#} +{%- if add_generation_prompt -%} + {{'<start_of_turn>model\n'}} +{%- endif -%} \ No newline at end of file diff --git 
a/examples/tool_chat_template_phi4_mini.jinja b/examples/tool_chat_template_phi4_mini.jinja index 36423b6c4240a..83886762c2893 100644 --- a/examples/tool_chat_template_phi4_mini.jinja +++ b/examples/tool_chat_template_phi4_mini.jinja @@ -1,10 +1,14 @@ -{%- if messages %} - {%- if system_message or tools %} -<|system|> - -{%- if system_message %} -{{ system_message }} +{%- if messages and messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "You are a helpful assistant." %} {%- endif %} + +{%- if messages %} +<|system|> +{{ system_message }} +{%- if tools %} In addition to plain text responses, you can chose to call one or more of the provided functions. Use the following rule to decide when to call a function: @@ -19,13 +23,11 @@ If you decide to call functions: * make sure you pick the right functions that match the user intent -{%- if tools %} {%- for t in tools %} {{- t | tojson(indent=4) }} {{- "\n\n" }} {%- endfor %} {%- endif %}<|end|> - {%- endif %} {%- for message in messages %} {%- if message.role != "system" %} diff --git a/examples/tool_chat_template_qwen3coder.jinja b/examples/tool_chat_template_qwen3coder.jinja new file mode 100644 index 0000000000000..49b0e8d0ee7e6 --- /dev/null +++ b/examples/tool_chat_template_qwen3coder.jinja @@ -0,0 +1,117 @@ +{% macro render_extra_keys(json_dict, handled_keys) %} + {%- if json_dict is mapping %} + {%- for json_key in json_dict if json_key not in handled_keys %} + {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %} + {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }} + {%- else %} + {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }} + {%- endif %} + {%- endfor %} + {%- endif %} +{% endmacro %} + +{%- if messages[0]["role"] == "system" %} + {%- set 
system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{%- if not tools is defined %} + {%- set tools = [] %} +{%- endif %} + +{%- if system_message is defined %} + {{- "<|im_start|>system\n" + system_message }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." }} + {%- endif %} +{%- endif %} +{%- if tools is iterable and tools | length > 0 %} + {{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }} + {{- "<tools>" }} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }} + {%- if tool.description is defined %} + {{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }} + {%- endif %} + {{- '\n<parameters>' }} + {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- '\n<parameter>' }} + {{- '\n<name>' ~ param_name ~ '</name>' }} + {%- if param_fields.type is defined %} + {{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }} + {%- endif %} + {%- if param_fields.description is defined %} + {{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }} + {%- endif %} + {%- set handled_keys = ['name', 'type', 'description'] %} + {{- render_extra_keys(param_fields, handled_keys) }} + {{- '\n</parameter>' }} + {%- endfor %} + {%- endif %} + {% set handled_keys = ['type', 'properties'] %} + {{- render_extra_keys(tool.parameters, handled_keys) }} + {{- '\n</parameters>' }} + {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %} + {{- render_extra_keys(tool, handled_keys) }} + {{- 
'\n</function>' }} + {%- endfor %} + {{- "\n</tools>" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }} +{%- endif %} +{%- if system_message is defined %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in loop_messages %} + {%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %} + {{- '<|im_start|>' + message.role }} + {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %} + {{- '\n' + message.content | trim + '\n' }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '<parameter=' + args_name + '>\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value 
| string %} + {{- args_value }} + {{- '\n</parameter>\n' }} + {%- endfor %} + {%- endif %} + {{- '</function>\n</tool_call>' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user\n' }} + {%- endif %} + {{- '<tool_response>\n' }} + {{- message.content }} + {{- '\n</tool_response>\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/mkdocs.yaml b/mkdocs.yaml index 47fe1ebce9712..507a80c41e8b4 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -129,15 +129,16 @@ markdown_extensions: - toc: permalink: true # For math rendering - - mdx_math: - enable_dollar_delimiter: true + - pymdownx.arithmatex: + generic: true extra_css: - mkdocs/stylesheets/extra.css extra_javascript: - mkdocs/javascript/run_llm_widget.js - - https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS_HTML + - mkdocs/javascript/mathjax.js + - https://unpkg.com/mathjax@3.2.2/es5/tex-mml-chtml.js - mkdocs/javascript/edit_and_feedback.js - mkdocs/javascript/slack_and_forum.js diff --git a/pyproject.toml b/pyproject.toml index 03a32ac0ba3d7..013f2a6cd59e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,13 +24,14 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Intended Audience :: Developers", "Intended Audience :: Information Technology", "Intended Audience 
:: Science/Research", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Information Analysis", ] -requires-python = ">=3.9,<3.13" +requires-python = ">=3.9,<3.14" dynamic = [ "version", "dependencies", "optional-dependencies"] [project.urls] diff --git a/requirements/common.txt b/requirements/common.txt index 1a8fea0dd7d93..e21abfb9a30bd 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -7,20 +7,21 @@ requests >= 2.26.0 tqdm blake3 py-cpuinfo -transformers >= 4.55.0 +transformers >= 4.55.2 tokenizers >= 0.21.1 # Required for fast incremental detokenization. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. aiohttp openai >= 1.99.1 # For Responses API with reasoning content -pydantic >= 2.10 +pydantic >= 2.11.7 prometheus_client >= 0.18.0 pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer -lm-format-enforcer >= 0.10.11, < 0.11 +lm-format-enforcer == 0.11.3 llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" -outlines_core == 0.2.10 +outlines_core == 0.2.10 ; platform_machine != "s390x" +outlines == 0.1.11 ; platform_machine == "s390x" # required for outlines backend disk cache diskcache == 5.6.3 lark == 1.2.2 @@ -38,7 +39,7 @@ pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 setuptools>=77.0.3,<80; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. 
-compressed-tensors == 0.10.2 # required for compressed-tensors +compressed-tensors == 0.11.0 # required for compressed-tensors depyf==0.19.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda functions in model_executor/models/registry.py watchfiles # required for http server to monitor the updates of TLS files diff --git a/requirements/cpu.txt b/requirements/cpu.txt index 6860275acab6f..f4b95b72898cc 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -1,8 +1,8 @@ # Common dependencies -r common.txt -numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding -numba == 0.61.2; python_version > '3.9' +numba == 0.60.0; python_version == '3.9' and platform_machine != "s390x" # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding +numba == 0.61.2; python_version > '3.9' and platform_machine != "s390x" # Dependencies for CPUs packaging>=24.2 diff --git a/requirements/docs.txt b/requirements/docs.txt index a24b9c7e924bf..d1c546398780a 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -7,27 +7,12 @@ mkdocs-awesome-nav mkdocs-glightbox mkdocs-git-revision-date-localized-plugin mkdocs-minify-plugin -python-markdown-math regex ruff # Required for argparse hook only -f https://download.pytorch.org/whl/cpu cachetools -cbor2 -cloudpickle -fastapi msgspec -openai -openai-harmony -partial-json-parser -pillow -psutil -pybase64 pydantic -setproctitle torch -transformers -zmq -uvloop -prometheus-client diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index 491fa06259631..a529bf4504e40 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -27,7 +27,7 @@ mistral_common[image,audio] >= 1.8.2 # required for voxtral test num2words # required for smolvlm test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for 
minicpm3 test -lm-eval[api]==0.4.8 # required for model evaluation test +lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test mteb>=1.38.11, <2 # required for mteb test transformers==4.52.4 tokenizers==0.21.1 diff --git a/requirements/rocm-build.txt b/requirements/rocm-build.txt index 94201543cd4f3..cbae9bbb8a9b3 100644 --- a/requirements/rocm-build.txt +++ b/requirements/rocm-build.txt @@ -6,7 +6,7 @@ torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 -triton==3.2 +triton==3.3.0 cmake>=3.26.1,<4 packaging>=24.2 setuptools>=77.0.3,<80.0.0 diff --git a/requirements/rocm.txt b/requirements/rocm.txt index 7038c9024c6b6..c3bb65b70a0b8 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -17,4 +17,4 @@ setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 runai-model-streamer==0.11.0 runai-model-streamer-s3==0.11.0 -conch-triton-kernels==1.2.1 +conch-triton-kernels==1.2.1 \ No newline at end of file diff --git a/requirements/test.in b/requirements/test.in index 6652bfdfe66c9..92c577c501632 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -32,9 +32,10 @@ num2words # required for smolvlm test open_clip_torch==2.32.0 # Required for nemotron_vl test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test -lm-eval[api]==0.4.8 # required for model evaluation test +# TODO: Use lm-eval[api]==0.4.10 once released +lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test mteb[bm25s]>=1.38.11, <2 # required for mteb test -transformers==4.55.0 +transformers==4.55.2 tokenizers==0.21.1 schemathesis>=3.39.15 # Required for openai schema test. 
# quantization @@ -53,3 +54,4 @@ runai-model-streamer-s3==0.11.0 fastsafetensors>=0.1.10 pydantic>=2.10 # 2.9 leads to error on python 3.10 terratorch==1.1rc2 # required for PrithviMAE test +decord==0.6.0 diff --git a/requirements/test.txt b/requirements/test.txt index ff9886a315976..0c27c9bb67e82 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -156,6 +156,8 @@ datasets==3.0.2 # mteb decorator==5.1.1 # via librosa +decord==0.6.0 + # via -r requirements/test.in dill==0.3.8 # via # datasets @@ -408,7 +410,7 @@ lightning-utilities==0.14.3 # torchmetrics llvmlite==0.44.0 # via numba -lm-eval==0.4.8 +lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # via -r requirements/test.in lxml==5.3.0 # via @@ -493,6 +495,7 @@ numpy==1.26.4 # contourpy # cupy-cuda12x # datasets + # decord # einx # encodec # evaluate @@ -742,7 +745,7 @@ pycparser==2.22 # via cffi pycryptodomex==3.22.0 # via blobfile -pydantic==2.11.5 +pydantic==2.11.7 # via # -r requirements/test.in # albumentations @@ -1139,7 +1142,7 @@ tqdm==4.66.6 # transformers tqdm-multiprocess==0.0.11 # via lm-eval -transformers==4.55.0 +transformers==4.55.2 # via # -r requirements/test.in # genai-perf diff --git a/requirements/tpu.txt b/requirements/tpu.txt index 7bb77c4a99636..7ea239b48ea26 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -11,6 +11,7 @@ ray[default] ray[data] setuptools==78.1.0 nixl==0.3.0 +tpu_info==0.4.0 # Install torch_xla --pre diff --git a/setup.py b/setup.py index 919300e143c1e..ffe8ec4e79af7 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,8 @@ MAIN_CUDA_VERSION = "12.8" def is_sccache_available() -> bool: - return which("sccache") is not None + return which("sccache") is not None and \ + not bool(int(os.getenv("VLLM_DISABLE_SCCACHE", "0"))) def is_ccache_available() -> bool: @@ -642,16 +643,25 @@ if envs.VLLM_USE_PRECOMPILED: if wheel_location is not None: wheel_url = wheel_location else: + import platform 
+ arch = platform.machine() + if arch == "x86_64": + wheel_tag = "manylinux1_x86_64" + elif arch == "aarch64": + wheel_tag = "manylinux2014_aarch64" + else: + raise ValueError(f"Unsupported architecture: {arch}") base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch() - wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" + wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl" + nightly_wheel_url = f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl" from urllib.request import urlopen try: with urlopen(wheel_url) as resp: if resp.status != 200: - wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" + wheel_url = nightly_wheel_url except Exception as e: print(f"[warn] Falling back to nightly wheel: {e}") - wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" + wheel_url = nightly_wheel_url patch = precompiled_wheel_utils.extract_precompiled_and_patch_package( wheel_url) @@ -684,7 +694,9 @@ setup( "mistral_common[audio]"], # Required for audio processing "video": [], # Kept for backwards compatibility # FlashInfer should be updated together with the Dockerfile - "flashinfer": ["flashinfer-python==0.2.11"], + "flashinfer": ["flashinfer-python==0.2.14.post1"], + # Optional deps for AMD FP4 quantization support + "petit-kernel": ["petit-kernel"], }, cmdclass=cmdclass, package_data=package_data, diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 13ddf035a55e0..a3b09cc817917 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -12,7 +12,6 @@ import pytest import torch from vllm import LLM, envs -from vllm.platforms import current_platform from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1 from ..conftest import HfRunner, 
VllmRunner @@ -78,11 +77,7 @@ def test_models( "VLLM_USE_V1") and envs.VLLM_USE_V1: pytest.skip("enable_prompt_embeds is not supported in v1.") - if backend == "FLASHINFER" and current_platform.is_rocm(): - pytest.skip("Flashinfer does not support ROCm/HIP.") - - if backend in ("XFORMERS", - "FLASHINFER") and model == "google/gemma-2-2b-it": + if backend == "XFORMERS" and model == "google/gemma-2-2b-it": pytest.skip( f"{backend} does not support gemma2 with full context length.") @@ -141,8 +136,6 @@ def test_models( ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}), ("distilbert/distilgpt2", "ray", "", "A100", {}), ("distilbert/distilgpt2", "mp", "", "A100", {}), - ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100", {}), - ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100", {}), ]) @pytest.mark.parametrize("enable_prompt_embeds", [True, False]) def test_models_distributed( diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py deleted file mode 100644 index 4816b76996fc8..0000000000000 --- a/tests/basic_correctness/test_chunked_prefill.py +++ /dev/null @@ -1,296 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Compare the outputs of HF and vLLM when using greedy sampling. - -It tests chunked prefill. Chunked prefill can be enabled by -enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens, -prefill requests are chunked. - -Run `pytest tests/models/test_chunked_prefill.py`. 
-""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from vllm.platforms import current_platform -from vllm.utils import STR_BACKEND_ENV_VAR - -from ..models.utils import check_logprobs_close, check_outputs_equal -from ..utils import multi_gpu_test - -if TYPE_CHECKING: - from .conftest import HfRunner, VllmRunner - -MODELS = [ - "facebook/opt-125m", - "meta-llama/Llama-3.2-1B-Instruct", -] - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch: pytest.MonkeyPatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the file. - """ - with monkeypatch.context() as m: - m.setenv('VLLM_USE_V1', '0') - yield - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) -@pytest.mark.parametrize("enforce_eager", [False, True]) -# NOTE: Increasing this in this suite will fail CI because we currently cannot -# reset distributed env properly. Use a value > 1 just when you test. -@pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("attention_backend", [ - pytest.param("FLASHINFER", - marks=pytest.mark.skipif( - current_platform.is_rocm(), - reason="FLASHINFER isn't supported on ROCm")), - "FLASH_ATTN" -]) -def test_models( - hf_runner: HfRunner, - vllm_runner: VllmRunner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - chunked_prefill_token_size: int, - enforce_eager: bool, - tensor_parallel_size: int, - attention_backend: str, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """ - Checks exact match decode between huggingface model and vllm runner with - chunked prefill. 
- """ - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, attention_backend) - - max_num_seqs = chunked_prefill_token_size - max_num_batched_tokens = chunked_prefill_token_size - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - with vllm_runner( - model, - dtype=dtype, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=True, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager, - max_num_seqs=max_num_seqs, - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("attention_backend", [ - pytest.param("FLASHINFER", - marks=pytest.mark.skipif( - current_platform.is_rocm(), - reason="FLASHINFER isn't supported on ROCm")), - "FLASH_ATTN" -]) -def test_models_distributed( - hf_runner: HfRunner, - vllm_runner: VllmRunner, - example_prompts, - model: str, - distributed_executor_backend: str, - attention_backend: str, - monkeypatch: pytest.MonkeyPatch, -) -> None: - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, attention_backend) - if (model == "meta-llama/Llama-3.2-1B-Instruct" - and distributed_executor_backend == "ray"): - # test Ray Compiled Graph - m.setenv("VLLM_USE_RAY_SPMD_WORKER", "1") - m.setenv("VLLM_USE_RAY_COMPILED_DAG", "1") - - dtype = "half" - max_tokens = 5 - chunked_prefill_token_size = 16 - - # Add a chunked prefill config. - max_num_seqs = min(chunked_prefill_token_size, 256) - assert chunked_prefill_token_size != -1 - enable_chunked_prefill = True - max_num_batched_tokens = chunked_prefill_token_size - - # NOTE: take care of the order. run vLLM first, and then run HF. 
- # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with - # fork method (the default method). - - with vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=2, - max_num_seqs=max_num_seqs, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - distributed_executor_backend=distributed_executor_backend, - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy( - example_prompts, - max_tokens, - ) - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - -@pytest.mark.parametrize( - "kv_cache_dtype,model", - [("fp8_e4m3", - "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme")]) -# Due to low-precision numerical divergence, we only test logprob of 4 tokens -@pytest.mark.parametrize("max_tokens", [4]) -@pytest.mark.parametrize("chunked_prefill_token_size", [4, 16]) -@pytest.mark.parametrize("enforce_eager", [False, True]) -# NOTE: Increasing this in this suite will fail CI because we currently cannot -# reset distributed env properly. Use a value > 1 just when you test. 
-@pytest.mark.parametrize("tensor_parallel_size", [1]) -# Due to low-precision numerical divergence, this test is too sensitive to -# the async postprocessor -@pytest.mark.parametrize("disable_async_output_proc", [True]) -@pytest.mark.skipif(current_platform.is_rocm(), - reason="machete_prepack_B isn't supported on ROCm") -def test_models_with_fp8_kv_cache( - vllm_runner: VllmRunner, - example_prompts, - kv_cache_dtype: str, - model: str, - max_tokens: int, - chunked_prefill_token_size: int, - enforce_eager: bool, - tensor_parallel_size: int, - disable_async_output_proc: bool, -) -> None: - """ - Check output logprobs match between no_chunked_prefill and chunked_prefill - with fp8 kv cache. General fp8 kv-cache tests are covered in test_fp8.py, - so here we only check chunked prefill. - """ - NUM_LOG_PROBS = 8 - - max_num_seqs = chunked_prefill_token_size - max_num_batched_tokens = chunked_prefill_token_size - - with vllm_runner( - model, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager, - max_num_seqs=max_num_seqs, - kv_cache_dtype=kv_cache_dtype, - disable_async_output_proc=disable_async_output_proc, - ) as vllm_model: - no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, NUM_LOG_PROBS) - - with vllm_runner( - model, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=True, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager, - max_num_seqs=max_num_seqs, - kv_cache_dtype=kv_cache_dtype, - disable_async_output_proc=disable_async_output_proc, - ) as vllm_model: - chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, NUM_LOG_PROBS) - - check_logprobs_close( - outputs_0_lst=no_chunked_prefill_outputs, - outputs_1_lst=chunked_prefill_outputs, - name_0="no_chunked_prefill", - name_1="chunked_prefill", - ) - - -@pytest.mark.parametrize("max_tokens", [16]) -@pytest.mark.parametrize("enforce_eager", [False]) 
-@pytest.mark.parametrize("chunk_size", [30, 32]) -# NOTE: Increasing this in this suite will fail CI because we currently cannot -# reset distributed env properly. Use a value > 1 just when you test. -@pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("dtype", ["half"]) -def test_with_prefix_caching( - vllm_runner: VllmRunner, - max_tokens: int, - enforce_eager: bool, - chunk_size: int, - tensor_parallel_size: int, - dtype: str, -) -> None: - """ - Checks exact match decode with and without prefix caching - with chunked prefill enabled. - """ - model = "meta-llama/Llama-3.2-1B-Instruct" - # The common prompt has 142 tokens with Llama-2 tokenizer. - common_prompt = "You are a helpful AI assistant " * 20 - unique_prompts = [ - "Question", # Warmup - "Question", # Fully cached - "Another question", # Partial cached - ] - full_prompts = [f"{common_prompt}\n{p}" for p in unique_prompts] - - max_num_batched_tokens = max_num_seqs = chunk_size - outputs = {} # type: ignore - for enable in (True, False): - with vllm_runner( - model, - dtype=dtype, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=True, - enable_prefix_caching=enable, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager, - max_num_seqs=max_num_seqs, - ) as vllm_model: - outputs[enable] = [] - for prompt in full_prompts: - outputs[enable] += vllm_model.generate_greedy( - [prompt], - max_tokens, - ) - - check_outputs_equal( - outputs_0_lst=outputs[False], - outputs_1_lst=outputs[True], - name_0="w/o prefix caching", - name_1="with prefix caching", - ) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index 34f9389c82a9b..f3ad680b72b55 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -177,3 +177,34 @@ def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool): # cmp output assert output[0].outputs[0].text == 
output3[0].outputs[0].text + + +@create_new_process_for_each_test() +def test_deep_sleep(): + model = "Qwen/Qwen3-0.6B" + free, total = torch.cuda.mem_get_info() + used_bytes_baseline = total - free # in case other process is running + llm = LLM(model, enable_sleep_mode=True) + prompt = "How are you?" + sampling_params = SamplingParams(temperature=0, max_tokens=10) + output = llm.generate(prompt, sampling_params) + + # Put the engine to deep sleep + llm.sleep(level=2) + + free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info() + used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline + assert used_bytes < 3 * GiB_bytes + + llm.wake_up(tags=["weights"]) + llm.collective_rpc("reload_weights") + free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info() + used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline + assert used_bytes < 4 * GiB_bytes + + # now allocate kv cache and cuda graph memory + llm.wake_up(tags=["kv_cache"]) + output2 = llm.generate(prompt, sampling_params) + + # cmp output + assert output[0].outputs[0].text == output2[0].outputs[0].text diff --git a/tests/benchmarks/test_random_dataset.py b/tests/benchmarks/test_random_dataset.py new file mode 100644 index 0000000000000..26cae369cdd5d --- /dev/null +++ b/tests/benchmarks/test_random_dataset.py @@ -0,0 +1,344 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random +from typing import Any, NamedTuple, Optional, cast + +import numpy as np +import pytest +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset, + SampleRequest) + + +@pytest.fixture(scope="session") +def hf_tokenizer() -> PreTrainedTokenizerBase: + # Use a small, commonly available tokenizer + return AutoTokenizer.from_pretrained("gpt2") + + +class Params(NamedTuple): + num_requests: int + prefix_len: int + range_ratio: float + input_len: int + 
output_len: int + + +@pytest.fixture(scope="session") +def random_dataset_params() -> Params: + return Params(num_requests=16, + prefix_len=7, + range_ratio=0.3, + input_len=50, + output_len=20) + + +def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]: + """Project a SampleRequest into a comparable tuple.""" + return (req.prompt, req.prompt_len, req.expected_output_len) + + +def _collect_samples(dataset: RandomDataset, + tokenizer: PreTrainedTokenizerBase, + num_requests: int = 16, + prefix_len: int = 7, + range_ratio: float = 0.3, + input_len: int = 50, + output_len: int = 20) -> list[tuple[str, int, int]]: + samples = dataset.sample( + tokenizer=tokenizer, + num_requests=num_requests, + prefix_len=prefix_len, + range_ratio=range_ratio, + input_len=input_len, + output_len=output_len, + ) + return [_fingerprint_sample(s) for s in samples] + + +@pytest.mark.benchmark +def test_random_dataset_same_seed( + hf_tokenizer: PreTrainedTokenizerBase, + random_dataset_params: Params) -> None: + """Same seed should yield identical outputs, even if global RNGs change. + + This guards against accidental reliance on Python's random or np.random + in RandomDataset after moving to numpy.default_rng. 
+ """ + p = random_dataset_params + common_seed = 123 + dataset_a = RandomDataset(random_seed=common_seed) + dataset_b = RandomDataset(random_seed=common_seed) + a = _collect_samples(dataset_a, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + + # Perturb global RNG state to ensure isolation + random.seed(999) + _ = [random.random() for _ in range(100)] + np.random.seed(888) + _ = [np.random.random() for _ in range(100)] + + b = _collect_samples(dataset_b, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + assert a == b + +@pytest.mark.benchmark +def test_random_dataset_different_seeds( + hf_tokenizer: PreTrainedTokenizerBase, + random_dataset_params: Params) -> None: + """Different seeds should change outputs with overwhelming likelihood.""" + p = random_dataset_params + seed_a = 0 + dataset_a = RandomDataset(random_seed=seed_a) + a = _collect_samples(dataset_a, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + + seed_b = 999 + dataset_b = RandomDataset(random_seed=seed_b) + # Perturb global RNG with same seed as dataset_a to ensure isolation + random.seed(seed_a) + np.random.seed(seed_a) + b = _collect_samples(dataset_b, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + assert a != b + + +# ----------------------------- +# RandomMultiModalDataset tests +# ----------------------------- + +def _mm_fingerprint_sample( + req: SampleRequest, +) -> tuple[str, int, int, int, list[str]]: + """Create a compact fingerprint for multimodal samples. 
+ + Includes: + - prompt string + - prompt_len + - expected_output_len + - count of multimodal items + - per-item type and URL prefix (e.g., 'data:image/jpeg;base64,') + """ + items = req.multi_modal_data or [] + item_prefixes: list[str] = [] + for it in items: + if isinstance(it, dict) and it.get("type") == "image_url": + url = it.get("image_url", {}).get("url", "") + # Only keep a short identifying prefix to avoid huge strings + item_prefixes.append(f"image:{url[:22]}") + elif isinstance(it, dict) and it.get("type") == "video_url": + url = it.get("video_url", {}).get("url", "") + item_prefixes.append(f"video:{url[:22]}") + else: + item_prefixes.append("unknown:") + return (req.prompt, req.prompt_len, req.expected_output_len, len(items), + item_prefixes) + + +def _collect_mm_samples( + dataset: RandomMultiModalDataset, + tokenizer: PreTrainedTokenizerBase, + *, + num_requests: int = 8, + prefix_len: int = 3, + range_ratio: float = 0.0, + input_len: int = 20, + output_len: int = 5, + base_items_per_request: int = 2, + num_mm_items_range_ratio: float = 0.0, + limit_mm_per_prompt: Optional[dict[str, int]] = None, + bucket_config: Optional[dict[tuple[int, int, int], float]] = None, + enable_multimodal_chat: bool = False, +) -> list[SampleRequest]: + if limit_mm_per_prompt is None: + limit_mm_per_prompt = {"image": 5, "video": 0} + if bucket_config is None: + bucket_config = {(32, 32, 1): 0.5, (52, 64, 1): 0.5} + return dataset.sample( + tokenizer=tokenizer, + num_requests=num_requests, + prefix_len=prefix_len, + range_ratio=range_ratio, + input_len=input_len, + output_len=output_len, + base_items_per_request=base_items_per_request, + num_mm_items_range_ratio=num_mm_items_range_ratio, + limit_mm_per_prompt=limit_mm_per_prompt, + bucket_config=bucket_config, + enable_multimodal_chat=enable_multimodal_chat, + ) + + +@pytest.mark.benchmark +def test_random_mm_same_seed(hf_tokenizer: PreTrainedTokenizerBase) -> None: + seed = 42 + ds_a = 
RandomMultiModalDataset(random_seed=seed) + ds_b = RandomMultiModalDataset(random_seed=seed) + a = _collect_mm_samples(ds_a, hf_tokenizer) + b = _collect_mm_samples(ds_b, hf_tokenizer) + fa = [_mm_fingerprint_sample(s) for s in a] + fb = [_mm_fingerprint_sample(s) for s in b] + assert fa == fb + + +@pytest.mark.benchmark +def test_random_mm_different_seeds( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds_a = RandomMultiModalDataset(random_seed=0) + ds_b = RandomMultiModalDataset(random_seed=999) + a = _collect_mm_samples(ds_a, hf_tokenizer) + b = _collect_mm_samples(ds_b, hf_tokenizer) + fa = [_mm_fingerprint_sample(s) for s in a] + fb = [_mm_fingerprint_sample(s) for s in b] + assert fa != fb + +@pytest.mark.benchmark +def test_random_mm_respects_limits( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Requesting 3 items with a per-prompt limit of 1 should error per current + # design (dataset refuses to silently clamp below the requested baseline). 
+ with pytest.raises(ValueError): + _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=12, + base_items_per_request=3, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 1, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + + +@pytest.mark.benchmark +def test_random_mm_zero_prob_entries_are_removed( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Second bucket has zero probability and should be ignored after + # normalization + samples = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=6, + base_items_per_request=2, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 10, "video": 0}, + bucket_config={(32, 32, 1): 1.0, (52, 64, 1): 0.0}, + ) + for s in samples: + assert isinstance(s.multi_modal_data, list) + typed_mm = cast(list[dict[str, Any]], s.multi_modal_data) + for it in typed_mm: + assert it.get("type") == "image_url" + + +@pytest.mark.benchmark +def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None: + ds = RandomMultiModalDataset(random_seed=0) + samples = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=0, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 5, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + for s in samples: + assert s.multi_modal_data == [] + +@pytest.mark.benchmark +def test_random_mm_num_items_per_prompt( + hf_tokenizer: PreTrainedTokenizerBase) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Fixed number of images per prompt + # set num_mm_items_range_ratio to 0.0 + # TODO: modify video values when video sampling is implemented + samples_fixed_items = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=3, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 3, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + # Must have 5 requests each with 3 mm items per prompt + assert len(samples_fixed_items) == 5 
+ for s in samples_fixed_items: + mm_data = cast(list[dict[str, Any]], s.multi_modal_data) + assert len(mm_data) == 3 + for it in mm_data: + assert it.get("type") == "image_url" + + +@pytest.mark.benchmark +def test_random_mm_bucket_config_not_mutated( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + + ds = RandomMultiModalDataset(random_seed=0) + # This bucket config is not normalized to sum to 1 + # and has more buckets than requested images + original = {(32, 32, 1): 0.2, (52, 64, 1): 6, (25, 64, 1): 3} + # Keep a snapshot to compare after sampling + snapshot = dict(original) + + _ = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=4, + base_items_per_request=1, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 1, "video": 0}, + bucket_config=original, + ) + + # Ensure the original dict content is unchanged + assert original == snapshot + + + # Vary number of mm items per prompt + # set num_mm_items_range_ratio to 0.5 + samples_varying_items = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=2, + num_mm_items_range_ratio=0.5, + limit_mm_per_prompt={"image": 4, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + # Must have 5 requests each with less than 4 mm items per prompt + # but at least 1 mm item per prompt + assert len(samples_varying_items) == 5 + for s in samples_varying_items: + mm_data = cast(list[dict[str, Any]], s.multi_modal_data) + assert len(mm_data) <= 4 + assert len(mm_data) >= 1 + for it in mm_data: + assert it.get("type") == "image_url" diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index e460d70951786..f5e2d9ddb7528 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -12,10 +12,9 @@ from vllm.compilation.backends import set_model_tag from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import 
(ignore_torch_compile, support_torch_compile) -from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, - set_current_vllm_config) -from vllm.envs import VLLM_USE_V1 -from vllm.forward_context import set_forward_context +from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, + VllmConfig, set_current_vllm_config) +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.utils import direct_register_custom_op # create a library to hold the custom op @@ -164,104 +163,34 @@ class SimpleModelWithTwoGraphs(ParentModel): return x -def test_ignore_torch_compile_decorator(): - assert VLLM_USE_V1 - - # piecewise - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - splitting_ops=["silly.attention"], - cudagraph_capture_sizes=[1, 2], - )) - - @support_torch_compile - class A(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: - super().__init__() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + x - attn_output = torch.empty_like(x) - torch.ops.silly.attention(x, x, x, attn_output) - x = attn_output - x = x * 3 - return x - - @ignore_torch_compile - class B(A): - ... - - @support_torch_compile - class C(B): - ... 
- - with set_current_vllm_config(vllm_config): - mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() - - # A has support_torch_compile - with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=3, - num_piecewise_capturable_graphs_seen=2, - num_backend_compilations=2, - num_cudagraph_captured=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - ), set_forward_context({}, vllm_config=vllm_config): - # first run is for compile - mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) - # run cudagraph captured sizes - mod_A(torch.randn(2, MLP_SIZE).cuda()) - mod_A(torch.randn(1, MLP_SIZE).cuda()) - - with set_current_vllm_config(vllm_config): - mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda() - - # B's ignore_torch_compile should override A's support_torch_compile - with compilation_counter.expect( - num_graphs_seen=0, - num_piecewise_graphs_seen=0, - num_piecewise_capturable_graphs_seen=0, - num_backend_compilations=0, - num_cudagraph_captured=0, - ), set_forward_context({}, vllm_config=vllm_config): - mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) - mod_B(torch.randn(2, MLP_SIZE).cuda()) - mod_B(torch.randn(1, MLP_SIZE).cuda()) - - with set_current_vllm_config(vllm_config): - mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda() - - # C's support_torch_compile should override B's ignore_torch_compile - with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=3, - num_piecewise_capturable_graphs_seen=2, - num_backend_compilations=2, - num_cudagraph_captured=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - ), set_forward_context({}, vllm_config=vllm_config): - mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) - mod_C(torch.randn(2, MLP_SIZE).cuda()) - mod_C(torch.randn(1, MLP_SIZE).cuda()) - - @torch.inference_mode -def run_model(vllm_config, model: nn.Module, inputs: torch.Tensor): +def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor, + 
cudagraph_runtime_mode: CUDAGraphMode): with set_forward_context({}, vllm_config=vllm_config): - # First run is for compile + # warmup for the model with cudagraph_mode NONE model(inputs) - # Run CUDAGraph captured sizes - model(inputs[:2]) - model(inputs[:1]) + # simulate cudagraphs capturing + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + model(inputs[:2]) + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=1, )): + model(inputs[:1]) - output = model(inputs[:2]) + # simulate cudagraphs replay + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + output = model(inputs[:2]) output = output.cpu() return output.cpu() @@ -277,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal(): splitting_ops=["silly.attention"], cudagraph_capture_sizes=[1, 2], )) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE with set_current_vllm_config(vllm_config): model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, @@ -299,11 +229,13 @@ def test_multi_graph_piecewise_compile_outputs_equal(): num_cudagraph_captured=8, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): - outputs.append(run_model(vllm_config, model, inputs)) + outputs.append( + run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # no compile or cudagraph vllm_config = VllmConfig(compilation_config=CompilationConfig( level=CompilationLevel.NO_COMPILATION, )) + cudagraph_runtime_mode = CUDAGraphMode.NONE with set_current_vllm_config(vllm_config): model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, @@ -318,7 +250,8 @@ def test_multi_graph_piecewise_compile_outputs_equal(): num_backend_compilations=0, num_cudagraph_captured=0, ): - outputs.append(run_model(vllm_config, 
model, inputs)) + outputs.append( + run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # piecewise compile without CUDA graph vllm_config = VllmConfig(compilation_config=CompilationConfig( @@ -326,6 +259,7 @@ def test_multi_graph_piecewise_compile_outputs_equal(): use_cudagraph=False, splitting_ops=["silly.attention"], )) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE with set_current_vllm_config(vllm_config): model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, @@ -340,7 +274,8 @@ def test_multi_graph_piecewise_compile_outputs_equal(): num_backend_compilations=4, num_cudagraph_captured=0, # no cudagraph captured ): - outputs.append(run_model(vllm_config, model, inputs)) + outputs.append( + run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # Generally don't expect outputs with and without inductor # to be bitwise equivalent diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index cf715cd03222c..422cb94b036ca 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -34,7 +34,7 @@ class TestSetting: model_args=["--max-model-len", "2048"], pp_size=2, tp_size=2, - attn_backend="FLASHINFER", + attn_backend="FLASH_ATTN", method="generate", fullgraph=True, ), diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py new file mode 100644 index 0000000000000..51f8ddd566d56 --- /dev/null +++ b/tests/compile/test_decorator.py @@ -0,0 +1,251 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch import nn +from torch.library import Library + +from vllm.compilation.counter import compilation_counter +from vllm.compilation.decorators import (ignore_torch_compile, + support_torch_compile) +from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, + CUDAGraphMode, VllmConfig, set_current_vllm_config) +from vllm.forward_context import 
BatchDescriptor, set_forward_context +from vllm.utils import direct_register_custom_op + +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa + +BATCH_SIZE = 32 +MLP_SIZE = 128 + + +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + out.copy_(q) + out += k + out += v + + +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + return + + +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) + + +@torch.inference_mode +def run_model(vllm_config: VllmConfig, model: nn.Module, + cudagraph_runtime_mode: CUDAGraphMode): + with set_forward_context({}, vllm_config=vllm_config): + # warmup for the model with cudagraph_mode NONE + model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + + # simulate cudagraphs capturing + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + model(torch.randn(2, MLP_SIZE).cuda()) + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=1, )): + model(torch.randn(1, MLP_SIZE).cuda()) + + # simulate cudagraphs replay + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + output = model(torch.randn(2, MLP_SIZE).cuda()) + + output = output.cpu() + return output.cpu() + + +def test_ignore_torch_compile_decorator(): + # piecewise + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + )) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE + + 
@support_torch_compile + class A(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x * 3 + return x + + @ignore_torch_compile + class B(A): + ... + + @support_torch_compile + class C(B): + ... + + with set_current_vllm_config(vllm_config): + mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + + # A has support_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=3, + num_piecewise_capturable_graphs_seen=2, + num_backend_compilations=2, + num_cudagraph_captured=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, mod_A, cudagraph_runtime_mode) + + with set_current_vllm_config(vllm_config): + mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda() + + # B's ignore_torch_compile should override A's support_torch_compile + with compilation_counter.expect( + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_backend_compilations=0, + num_cudagraph_captured=0, + ): + run_model(vllm_config, mod_B, cudagraph_runtime_mode) + + with set_current_vllm_config(vllm_config): + mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda() + + # C's support_torch_compile should override B's ignore_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=3, + num_piecewise_capturable_graphs_seen=2, + num_backend_compilations=2, + num_cudagraph_captured=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, mod_C, cudagraph_runtime_mode) + + +# Only enable torch.compile if +# vllm_config.cache_config.kv_sharing_fast_prefill=True +@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. 
+ kv_sharing_fast_prefill) +class B(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x + x + return x + + +# Only enable torch.compile if +# vllm_config.cache_config.kv_sharing_fast_prefill=False +@support_torch_compile(enable_if=lambda vllm_config: not vllm_config. + cache_config.kv_sharing_fast_prefill) +class A(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs) + self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mod1(x) + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = self.mod2(x) + return x + + +def test_conditional_compile_enable_if(): + vllm_config = VllmConfig(cache_config=CacheConfig( + kv_sharing_fast_prefill=True, ), + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + )) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE + + with set_current_vllm_config(vllm_config): + mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + + # A has support_torch_compile but enable_if fn returns False + # enalbe_if will be True for B, so we expect mod1 and mod2 + # to be compiled + with compilation_counter.expect( + num_graphs_seen=2, + num_piecewise_graphs_seen=6, + # 3 piecewise graphs per instance of B() + num_piecewise_capturable_graphs_seen=4, + num_backend_compilations=4, + num_cudagraph_captured=8, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, mod_A, 
cudagraph_runtime_mode) + + # Set kv_sharing_fast_prefill=False + # which will cause A to be compiled and B to not be compiled + vllm_config = VllmConfig(cache_config=CacheConfig( + kv_sharing_fast_prefill=False, ), + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + )) + + with set_current_vllm_config(vllm_config): + mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=7, + # 3 attn ops and 4 non-attn ops + num_piecewise_capturable_graphs_seen=4, + num_backend_compilations=4, + num_cudagraph_captured=8, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, mod_A, cudagraph_runtime_mode) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 72f962ed7484c..84178344a5f36 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -31,10 +31,6 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None): ] if all: - if is_quant_method_supported("aqlm"): - TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", { - "quantization": "aqlm" - })) # TODO: figure out why this fails. 
if False and is_quant_method_supported("gguf"): # noqa: SIM223 @@ -57,12 +53,6 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None): "quantization": "gptq_marlin_24" })) - if is_quant_method_supported("marlin"): - TEST_MODELS.append( - ("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", { - "quantization": "marlin" - })) - if not current_platform.is_rocm() and is_quant_method_supported("awq"): TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", { "quantization": "AWQ" diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index aade29b99de7e..0c7e6fbccf20c 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -8,11 +8,12 @@ import vllm.envs as envs from vllm import LLM, SamplingParams from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey, - kFp8DynamicTokenSym, kFp8StaticTensorSym) +from vllm.compilation.fusion import FUSED_OPS, FusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import CompilationConfig, PassConfig, VllmConfig +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym) from .backend import TestBackend diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 4a3820e20fd89..c4229f93464ac 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -7,13 +7,15 @@ import torch import vllm.envs as envs import vllm.plugins from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, - FusionPass, GroupShape, QuantKey) + FusionPass) from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import 
(CompilationConfig, CompilationLevel, PassConfig, VllmConfig) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, QuantKey, ScaleDesc) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity) + Fp8LinearOp, maybe_create_device_identity) from vllm.platforms import current_platform from .backend import TestBackend @@ -24,16 +26,14 @@ FP8_DTYPE = current_platform.fp8_dtype() class TestModel(torch.nn.Module): def __init__(self, hidden_size: int, eps: float, static: bool, - cutlass_fp8_enabled: bool, *args, **kwargs): + force_fp8_e4m3fnuz: bool, *args, **kwargs): super().__init__(*args, **kwargs) - self.cutlass_fp8_enabled = cutlass_fp8_enabled + self.force_fp8_e4m3fnuz = force_fp8_e4m3fnuz self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN - self.key = QuantKey(dtype=FP8_DTYPE, - static=static, - group_shape=group_shape, - symmetric=True) + quant_scale = ScaleDesc(torch.float32, static, group_shape) + self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) if static: self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] else: @@ -43,7 +43,7 @@ class TestModel(torch.nn.Module): for _ in range(2) ] self.fp8_linear = Fp8LinearOp( - cutlass_fp8_supported=cutlass_fp8_enabled, + force_fp8_e4m3fnuz=force_fp8_e4m3fnuz, act_quant_static=static, act_quant_group_shape=group_shape, ) @@ -81,12 +81,11 @@ class TestModel(torch.nn.Module): @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) -@pytest.mark.parametrize("cutlass_fp8_enabled", - [True, False] if CUTLASS_FP8_SUPPORTED else [False]) 
+@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm") def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, - cutlass_fp8_enabled): + force_fp8_e4m3fnuz): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) @@ -103,7 +102,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, fusion_pass = FusionPass.instance(vllm_config) backend = TestBackend(noop_pass, fusion_pass) - model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled) + model = TestModel(hidden_size, eps, static, force_fp8_e4m3fnuz) # First dimension dynamic x = torch.rand(num_tokens, hidden_size) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 4c3cf6c2a10cf..dd31e0db1f59f 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -148,7 +148,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [8]) @pytest.mark.parametrize("hidden_size", [16]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif( diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 70750eb9ac4ee..dba668cfa16a6 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy from typing import Optional import pytest @@ -7,13 +8,29 @@ import torch._dynamo from tests.compile.backend import TestBackend from tests.models.utils import check_outputs_equal +from tests.v1.attention.utils 
import (BatchSpec, _Backend, + create_common_attn_metadata) from vllm import LLM, SamplingParams -from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant +from vllm.attention import Attention +from vllm.attention.selector import global_force_attn_backend_context_manager +from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, + ModelConfig, PassConfig, SchedulerConfig, VllmConfig, + set_current_vllm_config) +from vllm.forward_context import get_forward_context, set_forward_context +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8StaticTensorSym, kNvfp4Quant) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp) from vllm.platforms import current_platform +from vllm.v1.kv_cache_interface import AttentionSpec + +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 # globals needed for string-import custom Dynamo backend field backend: Optional[TestBackend] = None @@ -90,9 +107,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, # check support attn_fusion_supported = [ - layer.impl.fused_output_quant_supported(quant_key.dtype, - quant_key.static, - quant_key.group_shape) + layer.impl.fused_output_quant_supported(quant_key) for key, layer in compile_config.static_forward_context.items() ] @@ -132,3 +147,309 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, # Reset backend to make sure llm2 gets released backend = None + + +class AttentionQuantPatternModel(torch.nn.Module): + """Base model for AttentionQuantPattern fusion.""" + + def 
__init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, + kv_cache_dtype: torch.dtype, device: torch.device, + vllm_config: VllmConfig, **kwargs): + super().__init__() + self.num_qo_heads = num_qo_heads + self.num_kv_heads = num_kv_heads + self.head_size = head_size + self.kv_cache_dtype = kv_cache_dtype + self.device = device + self.vllm_config = vllm_config + + self.attn = Attention( + num_heads=self.num_qo_heads, + head_size=self.head_size, + scale=1.0 / (self.head_size**0.5), + num_kv_heads=self.num_kv_heads, + cache_config=vllm_config.cache_config, + prefix="model.layers.0.self_attn.attn", + ) + + self.block_size = 16 + + # Initialize attn MetadataBuilder + self.builder = self.attn.attn_backend.get_builder_cls()( + kv_cache_spec=AttentionSpec( + block_size=self.block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + dtype=self.kv_cache_dtype, + use_mla=False, + ), + layer_names=[self.attn.layer_name], + vllm_config=self.vllm_config, + device=self.device, + ) + + def build_attn_metadata(self, batch_size: int): + """Initialize attention metadata.""" + + # Create common attn metadata + batch_spec = BatchSpec(seq_lens=[1] * batch_size, + query_lens=[1] * batch_size) + common_attn_metadata = create_common_attn_metadata( + batch_spec, + self.block_size, + self.device, + arange_block_indices=True) + + max_blocks = (max(batch_spec.seq_lens) + self.block_size - + 1) // self.block_size + num_blocks = batch_size * max_blocks + + # Create dummy KV cache for FlashInfer TRTLLM + # - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] + # - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] + # Create kv_cache in HND layout and permute to NHD layout + # (later will be permuted back to HND layout in forward pass) + kv_cache = torch.zeros(num_blocks, + 2, + self.num_kv_heads, + self.block_size, + self.head_size, + dtype=self.kv_cache_dtype, + device=self.device) + kv_cache = kv_cache.permute(0, 1, 3, 2, 4) + self.attn.kv_cache = 
[kv_cache] + + # Build attn metadata + self.attn_metadata = self.builder.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata) + + return self.attn_metadata + + +class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): + """Test model for AttentionFp8StaticQuantPattern fusion.""" + + quant_key = kFp8StaticTensorSym + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.quant_key.scale.static, + act_quant_group_shape=self.quant_key.scale.group_shape) + + hidden_size = self.num_qo_heads * self.head_size + self.w = kwargs.get( + "w", { + "weight": + torch.randn(hidden_size, hidden_size).to( + dtype=FP8_DTYPE, device=self.device).t(), + "wscale": + torch.tensor([1.0], dtype=torch.float32, device=self.device), + "scale": + torch.tensor([1.0], dtype=torch.float32, device=self.device), + }) + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """Forward pass that creates the pattern to be fused.""" + attn_output = self.attn(q, k, v) + return self.fp8_linear.apply(input=attn_output, + weight=self.w["weight"], + weight_scale=self.w["wscale"], + input_scale=self.w["scale"]) + + +class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): + """Test model for AttentionNvfp4QuantPattern fusion.""" + + quant_key = kNvfp4Quant + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + hidden_size = self.num_qo_heads * self.head_size + self.w = kwargs.get( + "w", { + "weight": + torch.randint(256, (hidden_size, hidden_size // 2), + dtype=FP4_DTYPE, + device=self.device), + "wscale_swizzled": + torch.randn(hidden_size, hidden_size // 16).to( + dtype=FP8_DTYPE, device=self.device), + "wscale": + torch.tensor([500], dtype=torch.float32, device=self.device), + "scale": + torch.tensor([0.002], dtype=torch.float32, device=self.device), + }) + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + 
"""Forward pass that creates the pattern to be fused.""" + attn_output = self.attn(q, k, v) + quant_output, output_block_scale = scaled_fp4_quant( + attn_output, 1 / self.w["scale"]) + return cutlass_scaled_fp4_mm(a=quant_output, + b=self.w["weight"], + block_scale_a=output_block_scale, + block_scale_b=self.w["wscale_swizzled"], + alpha=self.w["scale"] * self.w["wscale"], + out_dtype=attn_output.dtype) + + +@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("batch_size", [7, 256, 533]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("model_name, model_class", + [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + TestAttentionFp8StaticQuantPatternModel), + ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + TestAttentionNvfp4QuantPatternModel)]) +@pytest.mark.parametrize("backend", [_Backend.FLASHINFER]) +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") +@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") +@pytest.mark.skipif(not current_platform.is_device_capability((10, 0)), + reason="Only test on SM100(Blackwell)") +def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, + head_size: int, batch_size: int, + dtype: torch.dtype, model_name: str, + model_class: type[AttentionQuantPatternModel], + backend: _Backend, monkeypatch, dist_init): + """Test AttentionStaticQuantPattern fusion pass""" + + monkeypatch.setenv("VLLM_USE_V1", "1") + + device = torch.device("cuda:0") + torch.manual_seed(42) + + vllm_config = VllmConfig( + model_config=ModelConfig( + model=model_name, + max_model_len=2048, + ), + scheduler_config=SchedulerConfig(max_num_seqs=1024), + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + custom_ops=["+quant_fp8"], + ), + cache_config=CacheConfig(cache_dtype="fp8")) + + # Create test inputs + q = torch.randn(batch_size, + num_qo_heads * 
head_size, + dtype=dtype, + device=device) + k = torch.randn(batch_size, + num_kv_heads * head_size, + dtype=dtype, + device=device) + v = torch.randn(batch_size, + num_kv_heads * head_size, + dtype=dtype, + device=device) + + # Mark first dimension as dynamic for realistic testing + torch._dynamo.mark_dynamic(q, 0) + torch._dynamo.mark_dynamic(k, 0) + torch._dynamo.mark_dynamic(v, 0) + + # Run model directly without compilation and fusion + vllm_config_unfused = copy.deepcopy(vllm_config) + with set_current_vllm_config(vllm_config_unfused), set_forward_context( + attn_metadata=None, vllm_config=vllm_config_unfused + ), global_force_attn_backend_context_manager(backend): + model_unfused = model_class(num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + kv_cache_dtype=FP8_DTYPE, + device=device, + vllm_config=vllm_config_unfused) + model_unfused = model_unfused.to(device) + + forward_ctx = get_forward_context() + forward_ctx.attn_metadata = model_unfused.build_attn_metadata( + batch_size) + + # Run model directly without compilation and fusion + result_unfused = model_unfused(q, k, v) + + # Run model with attn fusion enabled + vllm_config.compilation_config.pass_config = PassConfig( + enable_attn_fusion=True, enable_noop=True) + with set_current_vllm_config(vllm_config), set_forward_context( + attn_metadata=None, vllm_config=vllm_config + ), global_force_attn_backend_context_manager(backend): + model_fused = model_class(num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + kv_cache_dtype=FP8_DTYPE, + device=device, + vllm_config=vllm_config, + w=model_unfused.w) + model_fused = model_fused.to(device) + + forward_ctx = get_forward_context() + forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size) + + # Create test backend with fusion passes enabled + noop_pass = NoOpEliminationPass(vllm_config) + attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw + ) + test_backend = 
TestBackend(noop_pass, attn_pass) + + # Compile model with fusion enabled + model_compiled = torch.compile(model_fused, + backend=test_backend, + fullgraph=True) + assert model_compiled.attn._o_scale_float is None + result_fused_1 = model_compiled(q, k, v) + + # After the 1st round of the forward pass, output quant scale should be + # loaded into the attn layer's _o_scale_float, the 2nd round should + # reuse the loaded _o_scale_float + assert model_compiled.attn._o_scale_float is not None + result_fused_2 = model_compiled(q, k, v) + assert model_compiled.attn._o_scale_float is not None + + # Check attn fusion support + quant_key = model_class.quant_key + attn_fusion_supported = [ + layer.impl.fused_output_quant_supported(quant_key) for key, layer in + vllm_config.compilation_config.static_forward_context.items() + ] + if any(attn_fusion_supported): + # Check quantization ops in the graph before and after fusion + test_backend.check_before_ops([QUANT_OPS[quant_key]], + fully_replaced=True) + + # Check attention ops in the graph before and after fusion + attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass)) + attn_nodes_post = list(find_op_nodes(ATTN_OP, + test_backend.graph_post_pass)) + + assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion" + assert len(attn_nodes_pre) == len(attn_nodes_post), \ + "Should have same number of attention nodes before and after fusion" + assert attn_nodes_pre[0].kwargs.get("output_scale") is None, \ + "Attention should not have output_scale before fusion" + assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \ + "Attention should have output_scale after fusion" + + assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \ + "Attention should not have output_block_scale before fusion" + if quant_key.dtype == FP8_DTYPE: + assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \ + "Attention should not have output_block_scale after FP8 fusion" + elif 
quant_key.dtype == FP4_DTYPE: + assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \ + "Attention should have output_block_scale after FP4 fusion" # noqa: E501 + + # Check that results are closed + torch.testing.assert_close(result_unfused, + result_fused_1, + atol=1e-2, + rtol=1e-2) + torch.testing.assert_close(result_unfused, + result_fused_2, + atol=1e-2, + rtol=1e-2) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index a6baa97fe6990..fb9f9dde22799 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -104,8 +104,7 @@ class TestQuantModel(torch.nn.Module): # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) - self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=True, - use_per_token_if_dynamic=False) + self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False) self.scale = torch.rand(1, dtype=torch.float32) # Create a weight that is compatible with torch._scaled_mm, diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 5351a3cf35ba5..0e1059e654479 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -12,7 +12,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - CUTLASS_FP8_SUPPORTED, Fp8LinearOp) + Fp8LinearOp) from vllm.platforms import current_platform from .backend import TestBackend @@ -20,7 +20,7 @@ from .backend import TestBackend class TestModel(torch.nn.Module): - def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args, + def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, *args, **kwargs): super().__init__(*args, **kwargs) self.silu_and_mul = SiluAndMul() @@ -32,7 +32,7 @@ class TestModel(torch.nn.Module): 
hidden_size).to(dtype=current_platform.fp8_dtype()).t()) self.fp8_linear = Fp8LinearOp( - cutlass_fp8_supported=cutlass_fp8_enabled, + force_fp8_e4m3fnuz=force_fp8_e4m3fnuz, act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR, ) @@ -48,12 +48,11 @@ class TestModel(torch.nn.Module): @pytest.mark.parametrize("num_tokens", [256]) @pytest.mark.parametrize("hidden_size", [64]) -@pytest.mark.parametrize("cutlass_fp8_enabled", - [True, False] if CUTLASS_FP8_SUPPORTED else [False]) +@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm") def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, - cutlass_fp8_enabled): + force_fp8_e4m3fnuz): torch.set_default_device("cuda") torch.set_default_dtype(torch.float16) @@ -64,7 +63,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, fusion_pass = ActivationQuantFusionPass(config) backend = TestBackend(NoOpEliminationPass(config), fusion_pass) - model = TestModel(hidden_size, cutlass_fp8_enabled) + model = TestModel(hidden_size, force_fp8_e4m3fnuz) # First dimension dynamic x = torch.rand(num_tokens, hidden_size * 2) diff --git a/tests/conftest.py b/tests/conftest.py index 3f3790cab8d35..f8bfdfc8e6259 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -456,7 +456,15 @@ class HfRunner: outputs = [] for inputs in all_inputs: output = self.model(**self.wrap_device(inputs)) - logits = output.logits.softmax(dim=-1)[0].tolist() + + problem_type = getattr(self.config, "problem_type", "") + + if problem_type == "regression": + logits = output.logits[0].tolist() + elif problem_type == "multi_label_classification": + logits = output.logits.sigmoid()[0].tolist() + else: + logits = output.logits.softmax(dim=-1)[0].tolist() outputs.append(logits) return outputs @@ -1014,15 +1022,17 @@ class VllmRunner: images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: 
Optional[PromptAudioInput] = None, + concurrency_limit: Optional[int] = None, ) -> list[tuple[list[list[int]], list[str]]]: inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) - outputs = self.llm.beam_search( - inputs, - BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens)) + outputs = self.llm.beam_search(inputs, + BeamSearchParams(beam_width=beam_width, + max_tokens=max_tokens), + concurrency_limit=concurrency_limit) returned_outputs = [] for output in outputs: token_ids = [x.tokens for x in output.sequences] diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py index 4d67eea2264b2..27fe27a880e3d 100644 --- a/tests/core/block/e2e/test_correctness_sliding_window.py +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -32,7 +32,7 @@ BLOCK_SIZE = 16 @pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"]) def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator, batch_size, seed, backend, monkeypatch): """ @@ -43,8 +43,6 @@ def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator, Additionally, we compare the results of the v1 and v2 managers. 
""" - if backend == "FLASHINFER" and current_platform.is_rocm(): - pytest.skip("Flashinfer does not support ROCm/HIP.") if backend == "XFORMERS" and current_platform.is_rocm(): pytest.skip("Xformers does not support ROCm/HIP.") @@ -96,7 +94,7 @@ def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator, @pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"]) def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed, backend, monkeypatch): """ @@ -107,8 +105,6 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed, The results with and without chunked prefill are not the same due to numerical instabilities. """ - if backend == "FLASHINFER" and current_platform.is_rocm(): - pytest.skip("Flashinfer does not support ROCm/HIP.") if backend == "XFORMERS" and current_platform.is_rocm(): pytest.skip("Xformers does not support ROCm/HIP.") override_backend_env_variable(monkeypatch, backend) diff --git a/tests/detokenizer/test_min_tokens.py b/tests/detokenizer/test_min_tokens.py new file mode 100644 index 0000000000000..887e83342536e --- /dev/null +++ b/tests/detokenizer/test_min_tokens.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import AutoTokenizer + +from vllm import SamplingParams +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.detokenizer import FastIncrementalDetokenizer + +PROMPT = "Hello, my name is Lee, and I'm a student in the " + \ + "college of engineering" + + +@pytest.mark.parametrize("min_tokens,stop,truth", [ + (0, None, " is Lee, and I'm a student in the college of engineering"), + (0, "e", " is L"), + (5, "e", " 
is Lee, and I'm a stud"), +]) +def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str): + """Test for a specific min_tokens and stop. + + See https://github.com/vllm-project/vllm/pull/22014 + """ + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") + all_prompt_ids = tokenizer(PROMPT, add_special_tokens=False).input_ids + + # The prompt is "Hello, my name is" + prompt_token_ids = all_prompt_ids[:4] + params = SamplingParams( + stop=stop, + min_tokens=min_tokens, + ) + request = EngineCoreRequest("", + prompt_token_ids, + None, + None, + None, + params, + None, + None, + 0.0, + None, + cache_salt=None, + data_parallel_rank=None) + + detokenizer = FastIncrementalDetokenizer(tokenizer, request) + + detokenizer.update(all_prompt_ids[4:], False) + assert detokenizer.output_text == truth diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index e2cb579e22dc4..8d84cc2d0ffe6 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -18,7 +18,8 @@ from vllm.distributed import (broadcast_tensor_dict, get_pp_group, tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter) -from ..utils import init_test_distributed_environment, multi_process_parallel +from ..utils import (init_test_distributed_environment, multi_gpu_test, + multi_process_parallel) @ray.remote(num_gpus=1, max_calls=1) @@ -226,8 +227,7 @@ def send_recv_test_worker( torch.testing.assert_close(test_tensor, recv_tensor) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("test_target", [ all_reduce_test_worker, all_gather_test_worker, @@ -241,8 +241,7 @@ def test_multi_process_tensor_parallel( multi_process_parallel(monkeypatch, tp_size, 1, test_target) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") 
+@multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("pp_size", [2]) @pytest.mark.parametrize( "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker]) @@ -254,8 +253,7 @@ def test_multi_process_pipeline_parallel( multi_process_parallel(monkeypatch, 1, pp_size, test_target) -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") +@multi_gpu_test(num_gpus=4) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("pp_size", [2]) @pytest.mark.parametrize("test_target", [ diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 12dd7c4222630..28150d7682378 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -233,6 +233,7 @@ MULTIMODAL_MODELS = { "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(), "allenai/Molmo-7B-D-0924": PPTestSettings.fast(), "AIDC-AI/Ovis2-1B": PPTestSettings.fast(), + "AIDC-AI/Ovis2.5-2B": PPTestSettings.fast(), "microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(), "mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"), "Qwen/Qwen-VL-Chat": PPTestSettings.fast(), diff --git a/tests/distributed/test_pp_cudagraph.py b/tests/distributed/test_pp_cudagraph.py index a027a9e37dd67..5ca65a0e8d2c9 100644 --- a/tests/distributed/test_pp_cudagraph.py +++ b/tests/distributed/test_pp_cudagraph.py @@ -17,7 +17,6 @@ if TYPE_CHECKING: ]) @pytest.mark.parametrize("ATTN_BACKEND", [ "FLASH_ATTN", - "FLASHINFER", ]) @create_new_process_for_each_test() def test_pp_cudagraph( diff --git a/tests/distributed/test_symm_mem_allreduce.py b/tests/distributed/test_symm_mem_allreduce.py new file mode 100644 index 0000000000000..5a804a389123b --- /dev/null +++ b/tests/distributed/test_symm_mem_allreduce.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import random +import typing + 
+import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import vllm.envs as envs +from vllm.distributed import cleanup_dist_env_and_memory +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce +from vllm.distributed.device_communicators.cuda_communicator import ( + CudaCommunicator) +from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, + get_tp_group, + init_distributed_environment, + initialize_model_parallel) +from vllm.platforms import current_platform +from vllm.utils import update_environment_variables + +torch.manual_seed(42) +random.seed(44) + +test_size_elements = 4 * 1024 * 1024 + + +def symm_mem_allreduce_worker(local_rank: int, world_size: int): + monkeypatch = pytest.MonkeyPatch() + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + dtype = torch.bfloat16 + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': '12345', + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + cuda_communicator = typing.cast(CudaCommunicator, + get_tp_group().device_communicator) + symm_mem_comm = cuda_communicator.symm_mem_comm + if symm_mem_comm is None or symm_mem_comm.disabled: + pytest.skip("SymmMemCommunicator is not available or disabled.") + + inp_direct_symm_mem = torch.randint(1, + 23, (test_size_elements, ), + dtype=dtype, + device=device) + if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem): + pytest.skip( + "SymmMemCommunicator isn't used for this world and input size." 
+ ) + + original_inp_direct_symm_mem = inp_direct_symm_mem.clone() + out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem) + assert out_direct_symm_mem is not None + + group = get_tensor_model_parallel_group().device_group + dist.all_reduce(original_inp_direct_symm_mem, group=group) + torch.testing.assert_close(out_direct_symm_mem, + original_inp_direct_symm_mem, + atol=2.5, + rtol=0.1) + + # Test tensor_model_parallel_all_reduce which should use symm_mem + inp_tensor_parallel = torch.randint(-23, + 1, (test_size_elements, ), + dtype=dtype, + device=device) + original_inp_tensor_parallel = inp_tensor_parallel.clone() + out_tensor_parallel = tensor_model_parallel_all_reduce( + inp_tensor_parallel) + dist.all_reduce(original_inp_tensor_parallel, group=group) + torch.testing.assert_close(out_tensor_parallel, + original_inp_tensor_parallel, + atol=2.5, + rtol=0.1) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="SymmMemAllreduce is only available for CUDA platforms.") +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("pipeline_parallel_size", [1]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], + reason="Only test on CUDA") +def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, + pipeline_parallel_size): + world_size = tp_size * pipeline_parallel_size + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + + # Enable SymmMemCommunicator + monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1") + + mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size) + cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index 97cf3b5ce8fcb..2cbfed98a577a 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -18,10 +18,9 @@ def text_llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield 
weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() @@ -88,10 +87,9 @@ def vision_llm(): seed=0, ) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() @@ -158,10 +156,9 @@ def thinking_llm(): seed=0, ) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/llm/test_classify.py b/tests/entrypoints/llm/test_classify.py index 71e76abcb7d2c..57705ff669075 100644 --- a/tests/entrypoints/llm/test_classify.py +++ b/tests/entrypoints/llm/test_classify.py @@ -35,10 +35,9 @@ def llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/llm/test_embedding.py b/tests/entrypoints/llm/test_embedding.py index ba20d7b9548ef..485f04ed6d849 100644 --- a/tests/entrypoints/llm/test_embedding.py +++ b/tests/entrypoints/llm/test_embedding.py @@ -26,10 +26,9 @@ def llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index b930f05bebd0f..cb54b16b0b044 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -5,11 +5,9 @@ import weakref import pytest -from vllm import LLM, PoolingParams, PoolingRequestOutput +from vllm import LLM, PoolingParams from vllm.distributed import cleanup_dist_env_and_memory -from ...models.utils import check_embeddings_close - MODEL_NAME = "intfloat/multilingual-e5-small" PROMPTS = [ @@ -48,57 +46,13 @@ def llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm 
cleanup_dist_env_and_memory() -def assert_outputs_match(o1: list[PoolingRequestOutput], - o2: list[PoolingRequestOutput]): - check_embeddings_close( - embeddings_0_lst=[o.outputs.data for o in o1], - embeddings_1_lst=[o.outputs.data for o in o2], - name_0="hf", - name_1="vllm", - tol=1e-2, - ) - - -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) -def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, - prompt_token_ids): - pooling_params = PoolingParams() - - with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.encode(prompt_token_ids=prompt_token_ids, - pooling_params=pooling_params) - - v2_output = llm.encode({"prompt_token_ids": prompt_token_ids}, - pooling_params=pooling_params) - assert_outputs_match(v1_output, v2_output) - - -@pytest.mark.skip_global_cleanup -def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): - pooling_params = PoolingParams() - - with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.encode(prompt_token_ids=TOKEN_IDS, - pooling_params=pooling_params) - - v2_output = llm.encode( - [{ - "prompt_token_ids": p - } for p in TOKEN_IDS], - pooling_params=pooling_params, - ) - assert_outputs_match(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_multiple_pooling_params(llm: LLM): pooling_params = [ diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index 707891f6bdd8d..3bbbcc755d134 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -5,7 +5,7 @@ import weakref import pytest -from vllm import LLM, RequestOutput, SamplingParams +from vllm import LLM, SamplingParams from vllm.distributed import cleanup_dist_env_and_memory MODEL_NAME = "distilbert/distilgpt2" @@ -41,50 +41,13 @@ def llm(): gpu_memory_utilization=0.10, enforce_eager=True) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) 
- del llm + del llm cleanup_dist_env_and_memory() -def assert_outputs_equal(o1: list[RequestOutput], o2: list[RequestOutput]): - assert [o.outputs for o in o1] == [o.outputs for o in o2] - - -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) -def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, - prompt_token_ids): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.generate(prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params) - - v2_output = llm.generate({"prompt_token_ids": prompt_token_ids}, - sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - -@pytest.mark.skip_global_cleanup -def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.generate(prompt_token_ids=TOKEN_IDS, - sampling_params=sampling_params) - - v2_output = llm.generate( - [{ - "prompt_token_ids": p - } for p in TOKEN_IDS], - sampling_params=sampling_params, - ) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_multiple_sampling_params(llm: LLM): sampling_params = [ diff --git a/tests/entrypoints/llm/test_generate_multiple_loras.py b/tests/entrypoints/llm/test_generate_multiple_loras.py index b7d53e31fd71b..a04f195692e9b 100644 --- a/tests/entrypoints/llm/test_generate_multiple_loras.py +++ b/tests/entrypoints/llm/test_generate_multiple_loras.py @@ -48,10 +48,9 @@ def llm(request, monkeypatch_module): max_num_seqs=128, enforce_eager=True) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/llm/test_reward.py b/tests/entrypoints/llm/test_reward.py index 361e2d0e1047f..de82cf8d40380 100644 --- 
a/tests/entrypoints/llm/test_reward.py +++ b/tests/entrypoints/llm/test_reward.py @@ -36,10 +36,9 @@ def llm(): trust_remote_code=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/llm/test_score.py b/tests/entrypoints/llm/test_score.py index dd4eae0ccc06e..5a1339b2addf4 100644 --- a/tests/entrypoints/llm/test_score.py +++ b/tests/entrypoints/llm/test_score.py @@ -33,10 +33,9 @@ def llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/offline_mode/test_offline_mode.py b/tests/entrypoints/offline_mode/test_offline_mode.py index a606eeab5887e..dd8d63ad319ac 100644 --- a/tests/entrypoints/offline_mode/test_offline_mode.py +++ b/tests/entrypoints/offline_mode/test_offline_mode.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for HF_HUB_OFFLINE mode""" +import dataclasses import importlib import sys @@ -9,6 +10,7 @@ import urllib3 from vllm import LLM from vllm.distributed import cleanup_dist_env_and_memory +from vllm.engine.arg_utils import EngineArgs MODEL_CONFIGS = [ { @@ -108,3 +110,36 @@ def _re_import_modules(): # Error this test if reloading a module failed if reload_exception is not None: raise reload_exception + + +@pytest.mark.skip_global_cleanup +@pytest.mark.usefixtures("cache_models") +def test_model_from_huggingface_offline(monkeypatch: pytest.MonkeyPatch): + # Set HF to offline mode and ensure we can still construct an LLM + with monkeypatch.context() as m: + try: + m.setenv("HF_HUB_OFFLINE", "1") + m.setenv("VLLM_NO_USAGE_STATS", "1") + + def disable_connect(*args, **kwargs): + raise RuntimeError("No http calls allowed") + + m.setattr( + urllib3.connection.HTTPConnection, + "connect", + 
disable_connect, + ) + m.setattr( + urllib3.connection.HTTPSConnection, + "connect", + disable_connect, + ) + # Need to re-import huggingface_hub + # and friends to setup offline mode + _re_import_modules() + engine_args = EngineArgs(model="facebook/opt-125m") + LLM(**dataclasses.asdict(engine_args)) + finally: + # Reset the environment after the test + # NB: Assuming tests are run in online mode + _re_import_modules() diff --git a/tests/entrypoints/openai/test_collective_rpc.py b/tests/entrypoints/openai/test_collective_rpc.py new file mode 100644 index 0000000000000..37c0b7a900ac4 --- /dev/null +++ b/tests/entrypoints/openai/test_collective_rpc.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any + +import pytest +import requests + +from tests.utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen3-0.6B" + + +class TestWorkerExtension: + + def get_model_name(self) -> str: + """Test non-pydantic return type.""" + return MODEL_NAME + + def echo_args_kwargs(self, *args, **kwargs) -> dict[str, Any]: + """Echo back both args and kwargs.""" + return dict( + args=list(args), + kwargs=kwargs, + total_items=len(args) + len(kwargs), + ) + + def return_none(self, *args, **kwargs) -> None: + """Test method that does not return anything""" + return + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--max-model-len", + "8192", + "--max-num-seqs", + "128", + "--worker-extension-cls", + "tests.entrypoints.openai.test_collective_rpc.TestWorkerExtension", + ] + with RemoteOpenAIServer( + MODEL_NAME, + args, + env_dict={ + "VLLM_SERVER_DEV_MODE": "1", + "CUDA_VISIBLE_DEVICES": "0" + }, + ) as remote_server: + yield remote_server + + +def test_get_model_name(server): + """Test basic response""" + response = requests.post(server.url_for("collective_rpc"), + json={"method": "get_model_name"}) + assert response.status_code == 200 + results = response.json() + assert 
"results" in results + assert results["results"] == [MODEL_NAME] + + +def test_return_none(server): + """Test return none""" + response = requests.post(server.url_for("collective_rpc"), + json={"method": "return_none"}) + assert response.status_code == 200 + results = response.json() + assert results["results"] == [None] + + +def test_echo_args_kwargs(server): + """Test args, kwargs, and dict response""" + args = ["arg1", "arg2"] + kwargs = {"key1": "value1", "key2": "value2"} + response = requests.post(server.url_for("collective_rpc"), + json={ + "method": "echo_args_kwargs", + "args": args, + "kwargs": kwargs + }) + assert response.status_code == 200 + results = response.json() + result = results["results"][0] + assert result["args"] == args + assert result["kwargs"] == kwargs + assert result["total_items"] == len(args) + len(kwargs) diff --git a/tests/entrypoints/openai/test_completion_with_function_calling.py b/tests/entrypoints/openai/test_completion_with_function_calling.py index a5b081f861074..4ef5d4e8a699a 100644 --- a/tests/entrypoints/openai/test_completion_with_function_calling.py +++ b/tests/entrypoints/openai/test_completion_with_function_calling.py @@ -13,6 +13,127 @@ from ...utils import RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "Qwen/Qwen3-0.6B" +tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": + "The city to find the weather for, e.g. 'Vienna'", + "default": "Vienna", + }, + "country": { + "type": + "string", + "description": + "The country that the city is in, e.g. 
'Austria'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + "options": { + "$ref": "#/$defs/WeatherOptions", + "description": "Optional parameters for weather query", + }, + }, + "required": ["country", "unit"], + "$defs": { + "WeatherOptions": { + "title": "WeatherOptions", + "type": "object", + "additionalProperties": False, + "properties": { + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "default": "celsius", + "description": "Temperature unit", + "title": "Temperature Unit", + }, + "include_forecast": { + "type": "boolean", + "default": False, + "description": + "Whether to include a 24-hour forecast", + "title": "Include Forecast", + }, + "language": { + "type": "string", + "default": "zh-CN", + "description": "Language of the response", + "title": "Language", + "enum": ["zh-CN", "en-US", "ja-JP"], + }, + }, + }, + }, + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_forecast", + "description": "Get the weather forecast for a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": + "The city to get the forecast for, e.g. 'Vienna'", + "default": "Vienna", + }, + "country": { + "type": + "string", + "description": + "The country that the city is in, e.g. 'Austria'", + }, + "days": { + "type": + "integer", + "description": + "Number of days to get the forecast for (1-7)", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["country", "days", "unit"], + }, + }, + }, +] + +messages = [ + { + "role": "user", + "content": "Hi! How are you doing today?" + }, + { + "role": "assistant", + "content": "I'm doing well! How can I help you?" 
+ }, + { + "role": + "user", + "content": + "Can you tell me what the current weather is in Berlin and the "\ + "forecast for the next 5 days, in fahrenheit?", + }, +] + @pytest.fixture(scope="module") def server(): # noqa: F811 @@ -27,6 +148,8 @@ def server(): # noqa: F811 "hermes", "--reasoning-parser", "qwen3", + "--gpu-memory-utilization", + "0.4" ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -54,129 +177,6 @@ async def client(server): async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, stream: bool, tool_choice: Union[str, dict], enable_thinking: bool): - tools = [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": - "The city to find the weather for, e.g. 'Vienna'", - "default": "Vienna", - }, - "country": { - "type": - "string", - "description": - "The country that the city is in, e.g. 
'Austria'", - }, - "unit": { - "type": "string", - "description": - "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"], - }, - "options": { - "$ref": "#/$defs/WeatherOptions", - "description": - "Optional parameters for weather query", - }, - }, - "required": ["country", "unit"], - "$defs": { - "WeatherOptions": { - "title": "WeatherOptions", - "type": "object", - "additionalProperties": False, - "properties": { - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "default": "celsius", - "description": "Temperature unit", - "title": "Temperature Unit", - }, - "include_forecast": { - "type": "boolean", - "default": False, - "description": - "Whether to include a 24-hour forecast", - "title": "Include Forecast", - }, - "language": { - "type": "string", - "default": "zh-CN", - "description": "Language of the response", - "title": "Language", - "enum": ["zh-CN", "en-US", "ja-JP"], - }, - }, - }, - }, - }, - }, - }, - { - "type": "function", - "function": { - "name": "get_forecast", - "description": "Get the weather forecast for a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": - "The city to get the forecast for, e.g. 'Vienna'", - "default": "Vienna", - }, - "country": { - "type": - "string", - "description": - "The country that the city is in, e.g. 'Austria'", - }, - "days": { - "type": - "integer", - "description": - "Number of days to get the forecast for (1-7)", - }, - "unit": { - "type": "string", - "description": - "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"], - }, - }, - "required": ["country", "days", "unit"], - }, - }, - }, - ] - - messages = [ - { - "role": "user", - "content": "Hi! How are you doing today?" - }, - { - "role": "assistant", - "content": "I'm doing well! How can I help you?" 
- }, - { - "role": - "user", - "content": - "Can you tell me what the current weather is in Berlin and the "\ - "forecast for the next 5 days, in fahrenheit?", - }, - ] if not stream: # Non-streaming test chat_completion = await client.chat.completions.create( @@ -216,3 +216,71 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, output.extend(chunk.choices[0].delta.tool_calls) assert len(output) > 0 + + +@pytest.fixture(scope="module") +def k2_server(): # noqa: F811 + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "half", + "--enable-auto-tool-choice", + "--guided-decoding-backend", + "xgrammar", + "--tool-call-parser", + "hermes", + "--reasoning-parser", + "qwen3", + "--gpu-memory-utilization", + "0.4", + ] + # hack to test kimi_k2 tool use tool_id format. + # avoid error in is_deepseek_mla check by setting kv_lora_rank=null + with RemoteOpenAIServer(MODEL_NAME, + args, + override_hf_configs={ + "model_type": 'kimi_k2', + 'kv_lora_rank': None + }) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def k2_client(k2_server): + async with k2_server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("stream", [True, False]) +@pytest.mark.parametrize("tool_choice", ["required"]) +async def test_tool_id_kimi_k2(k2_client: openai.AsyncOpenAI, model_name: str, + stream: bool, tool_choice: str): + + if not stream: + # Non-streaming test + chat_completion = await k2_client.chat.completions.create( + messages=messages, + model=model_name, + tools=tools, + tool_choice=tool_choice) + assert chat_completion.choices[0].message.tool_calls is not None + assert len(chat_completion.choices[0].message.tool_calls) > 0 + assert chat_completion.choices[0].message.tool_calls[ + 0].id == 'functions.get_current_weather:0' + else: + # Streaming test + output_stream = await 
k2_client.chat.completions.create( + messages=messages, + model=model_name, + tools=tools, + tool_choice=tool_choice, + stream=True) + + output = [] + async for chunk in output_stream: + if chunk.choices and chunk.choices[0].delta.tool_calls: + output.extend(chunk.choices[0].delta.tool_calls) + for o in output: + assert o.id is None or o.id == 'functions.get_current_weather:0' diff --git a/tests/entrypoints/openai/test_default_mm_loras.py b/tests/entrypoints/openai/test_default_mm_loras.py index 372e9b1fecd42..b9c466a6fbeb6 100644 --- a/tests/entrypoints/openai/test_default_mm_loras.py +++ b/tests/entrypoints/openai/test_default_mm_loras.py @@ -48,7 +48,8 @@ def multimodal_server(): # noqa: F811 f"{{\"audio\": \"{AUDIO_LORA_PATH}\"}}", ] - with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args) as remote_server: + with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args, + max_wait_seconds=480) as remote_server: yield remote_server diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index 9107d089834bb..ff2e7004ff9f8 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import asyncio import subprocess import sys import tempfile @@ -294,6 +294,99 @@ async def test_metrics_exist(server: RemoteOpenAIServer, assert metric in response.text +@pytest.mark.asyncio +async def test_abort_metrics_reset(server: RemoteOpenAIServer, + client: openai.AsyncClient, use_v1: bool): + + running_requests, waiting_requests, kv_cache_usage = ( + _get_running_metrics_from_api(server)) + + # Expect no running requests or kvcache usage + assert running_requests == 0 + assert waiting_requests == 0 + assert kv_cache_usage == 0.0 + + # Start some long-running requests that we can abort + tasks = [] + for _ in range(3): + task = asyncio.create_task( + client.completions.create( + 
model=MODEL_NAME, + prompt=_TOKENIZED_PROMPT, + max_tokens=100, # Long generation to give time to abort + temperature=0.0)) + tasks.append(task) + + # Wait a bit for requests to start processing + await asyncio.sleep(0.5) + + # Check that we have running requests + running_requests, waiting_requests, kv_cache_usage = ( + _get_running_metrics_from_api(server)) + + # Expect running requests and kvcache usage + assert running_requests > 0 + assert kv_cache_usage > 0 + + # Cancel all tasks to abort the requests + for task in tasks: + task.cancel() + + # Wait for cancellations to be processed + await asyncio.sleep(1.0) + + # Check that metrics have reset to zero + response = requests.get(server.url_for("metrics")) + assert response.status_code == HTTPStatus.OK + + # Verify running and waiting requests counts and KV cache usage are zero + running_requests_after, waiting_requests_after, kv_cache_usage_after = ( + _get_running_metrics_from_api(server)) + + assert running_requests_after == 0,\ + (f"Expected 0 running requests after abort, got " + f"{running_requests_after}") + assert waiting_requests_after == 0,\ + (f"Expected 0 waiting requests after abort, got " + f"{waiting_requests_after}") + assert kv_cache_usage_after == 0,\ + (f"Expected 0% KV cache usage after abort, got " + f"{kv_cache_usage_after}") + + +def _get_running_metrics_from_api(server: RemoteOpenAIServer): + """Return (running_count, waiting_count, kv_cache_usage)""" + + response = requests.get(server.url_for("metrics")) + assert response.status_code == HTTPStatus.OK + + # Verify running and waiting requests counts and KV cache usage are zero + running_requests, waiting_requests, kv_cache_usage = None, None, None + + for family in text_string_to_metric_families(response.text): + if family.name == "vllm:num_requests_running": + for sample in family.samples: + if sample.name == "vllm:num_requests_running": + running_requests = sample.value + break + elif family.name == "vllm:num_requests_waiting": + for 
sample in family.samples: + if sample.name == "vllm:num_requests_waiting": + waiting_requests = sample.value + break + elif family.name == "vllm:gpu_cache_usage_perc": + for sample in family.samples: + if sample.name == "vllm:gpu_cache_usage_perc": + kv_cache_usage = sample.value + break + + assert running_requests is not None + assert waiting_requests is not None + assert kv_cache_usage is not None + + return running_requests, waiting_requests, kv_cache_usage + + def test_metrics_exist_run_batch(use_v1: bool): input_batch = """{"custom_id": "request-0", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}}""" # noqa: E501 diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index 246bd014aa690..11ed1c4a9ee4b 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -74,31 +74,44 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy): -d '{"messages": [{"role": "assistant", "tool_calls": [{"custom": {"input": "", "name": ""}, "id": "", "type": "custom"}]}]}' \ http://localhost:8000/v1/chat/completions """ # noqa: E501 - if (hasattr(case, "body") and isinstance(case.body, dict) - and "messages" in case.body - and isinstance(case.body["messages"], list) - and len(case.body["messages"]) > 0): + if hasattr(case, "body") and isinstance(case.body, dict): + if ("messages" in case.body + and isinstance(case.body["messages"], list) + and len(case.body["messages"]) > 0): - for message in case.body["messages"]: - if not isinstance(message, dict): - continue + for message in case.body["messages"]: + if not isinstance(message, dict): + continue - # Check for invalid file type in tokenize endpoint - if op.method.lower() == "post" and op.path == "/tokenize": - content = message.get("content", []) - if (isinstance(content, list) and len(content) > 0 and any( - 
item.get("type") == "file" for item in content)): - return False + # Check for invalid file type in tokenize endpoint + if op.method.lower() == "post" and op.path == "/tokenize": + content = message.get("content", []) + if (isinstance(content, list) and len(content) > 0 + and any( + item.get("type") == "file" + for item in content)): + return False + + # Check for invalid tool_calls with non-function types + tool_calls = message.get("tool_calls", []) + if isinstance(tool_calls, list): + for tool_call in tool_calls: + if isinstance(tool_call, dict): + if tool_call.get("type") != "function": + return False + if "custom" in tool_call: + return False + + # Sometimes guided_grammar is generated to be empty + # Causing a server error in EBNF grammar parsing + # https://github.com/vllm-project/vllm/pull/22587#issuecomment-3195253421 + guided_grammar = case.body.get("guided_grammar") + + if guided_grammar == '': + # Allow None (will be handled as no grammar) + # But skip empty strings + return False - # Check for invalid tool_calls with non-function types - tool_calls = message.get("tool_calls", []) - if isinstance(tool_calls, list): - for tool_call in tool_calls: - if isinstance(tool_call, dict): - if tool_call.get("type") != "function": - return False - if "custom" in tool_call: - return False return True return strategy.filter(no_invalid_types) diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py index e31a1d077608f..4197583074dfe 100644 --- a/tests/entrypoints/openai/test_prompt_validation.py +++ b/tests/entrypoints/openai/test_prompt_validation.py @@ -1,10 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import io + # imports for guided decoding tests import openai +import pybase64 import pytest import regex as re +import torch + +from vllm.entrypoints.openai.serving_engine import OpenAIServing from ...utils import RemoteOpenAIServer @@ 
-42,3 +48,46 @@ async def test_out_of_vocab_token_ids(): prompt=[999999], max_tokens=5, temperature=0.0) + + +@pytest.mark.parametrize("dtype", + [torch.float32, torch.bfloat16, torch.float16]) +@pytest.mark.parametrize( + "layout", + [torch.strided, torch.sparse_coo, torch.sparse_csc, torch.sparse_csr]) +@pytest.mark.parametrize("seq_len", [2, 10]) +@pytest.mark.parametrize("hidden_size", [2, 10]) +def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout, + seq_len: int, hidden_size: int): + # construct arbitrary tensors of various dtypes, layouts, and sizes. + # We need to check against different layouts to make sure that if a user + # uses sparse tensors to reduce the transmission size of prompt embeddings, + # we must cast them to dense/strided before passing them into the engine. + # We don't use non-CPU tensors in this test to avoid preemptively + # initializing cuda and break other tests in the suite that fork processes. + # We also need to make sure that we only use devices that are actually + # available in the environment the test is running on. For simplicity, + # we just test against CPU. 
+ tensor = torch.randn((seq_len, hidden_size), dtype=dtype) + if layout == torch.strided: + tensor = tensor.contiguous() + elif layout == torch.sparse_coo: + tensor = tensor.to_sparse_coo() + elif layout == torch.sparse_csc: + tensor = tensor.to_sparse_csc() + elif layout == torch.sparse_csr: + tensor = tensor.to_sparse_csr() + + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + encoded_tensor = pybase64.b64encode(buffer.getvalue()) + + loaded_prompt_embeds = OpenAIServing._load_prompt_embeds(encoded_tensor) + assert len(loaded_prompt_embeds) == 1 + loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"] + assert loaded_tensor.device.type == "cpu" + assert loaded_tensor.layout == torch.strided + torch.testing.assert_close(loaded_tensor, + tensor.to("cpu").to_dense(), + equal_nan=True) diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py index 1ca52599c519d..72d468db08f65 100644 --- a/tests/entrypoints/openai/test_response_api_with_harmony.py +++ b/tests/entrypoints/openai/test_response_api_with_harmony.py @@ -11,18 +11,25 @@ from openai import BadRequestError, NotFoundError, OpenAI from ...utils import RemoteOpenAIServer -pytest.skip(allow_module_level=True, reason="gpt-oss can't run on CI yet.") - MODEL_NAME = "openai/gpt-oss-20b" -DTYPE = "bfloat16" @pytest.fixture(scope="module") -def server(): +def monkeypatch_module(): + from _pytest.monkeypatch import MonkeyPatch + mpatch = MonkeyPatch() + yield mpatch + mpatch.undo() + + +@pytest.fixture(scope="module") +def server(monkeypatch_module: pytest.MonkeyPatch): args = ["--enforce-eager", "--tool-server", "demo"] - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server + with monkeypatch_module.context() as m: + m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1") + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server @pytest_asyncio.fixture @@ -269,10 +276,11 @@ 
async def test_stateful_multi_turn(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_streaming(client: OpenAI, model_name: str): + # TODO: Add back when web search and code interpreter are available in CI prompts = [ "tell me a story about a cat in 20 words", - "What is 13 * 24? Use python to calculate the result.", - "When did Jensen found NVIDIA? Search it and answer the year only.", + # "What is 13 * 24? Use python to calculate the result.", + # "When did Jensen found NVIDIA? Search it and answer the year only.", ] for prompt in prompts: @@ -281,15 +289,15 @@ async def test_streaming(client: OpenAI, model_name: str): input=prompt, reasoning={"effort": "low"}, tools=[ - { - "type": "web_search_preview" - }, - { - "type": "code_interpreter", - "container": { - "type": "auto" - } - }, + # { + # "type": "web_search_preview" + # }, + # { + # "type": "code_interpreter", + # "container": { + # "type": "auto" + # } + # }, ], stream=True, ) @@ -317,6 +325,7 @@ async def test_streaming(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.skip(reason="Web search tool is not available in CI yet.") async def test_web_search(client: OpenAI, model_name: str): response = await client.responses.create( model=model_name, @@ -331,6 +340,7 @@ async def test_web_search(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.") async def test_code_interpreter(client: OpenAI, model_name: str): response = await client.responses.create( model=model_name, @@ -436,6 +446,7 @@ async def test_function_calling(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.flaky(reruns=5) async def test_function_calling_multi_turn(client: OpenAI, model_name: str): tools 
= [ { diff --git a/tests/entrypoints/openai/test_return_token_ids.py b/tests/entrypoints/openai/test_return_token_ids.py new file mode 100644 index 0000000000000..6addcb41c4098 --- /dev/null +++ b/tests/entrypoints/openai/test_return_token_ids.py @@ -0,0 +1,374 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.transformers_utils.tokenizer import get_tokenizer + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enable-auto-tool-choice", + "--tool-call-parser", + "hermes", + "--enforce-eager", + ] + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +async def test_basic_completion_with_emoji(server): + """Test basic completion with emoji to verify token_ids field.""" + async with server.get_async_client() as client: + # Test with return_token_ids enabled + completion = await client.completions.create( + model=MODEL_NAME, + prompt="Complete this sentence with emojis: I love coding 🚀", + max_tokens=10, + temperature=0, + logprobs=1, + extra_body={"return_token_ids": True}, + ) + + # Check the raw response to see the structure + completion_dict = completion.model_dump() + + # Verify prompt_token_ids field is present in the completion response + assert "prompt_token_ids" in completion_dict["choices"][0] + assert isinstance(completion.choices[0].prompt_token_ids, list) + + # Check against the expected prompt token IDs + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + encoded_tokens = tokenizer.encode( + "Complete this sentence with emojis: I love coding 🚀") + # Check that encoded_tokens is a subsequence of prompt_token_ids + assert any(completion.choices[0].prompt_token_ids[i:i + + len(encoded_tokens)] + == encoded_tokens for i in range( + 
len(completion.choices[0].prompt_token_ids) - + len(encoded_tokens) + 1)) + + # Verify token_ids field is present in the choice + assert completion.choices[0].token_ids is not None + assert isinstance(completion.choices[0].token_ids, list) + assert len(completion.choices[0].token_ids) > 0 + + # Verify decoding works correctly + decoded_text = tokenizer.decode(completion.choices[0].token_ids) + # The decoded text should contain a <|im_end|> at the end + assert decoded_text.startswith(completion.choices[0].text) + + # Test without return_token_ids (should be None) + completion_without = await client.completions.create( + model=MODEL_NAME, + prompt="Complete this sentence with emojis: I love coding 🚀", + max_tokens=10, + temperature=0, + logprobs=1, + extra_body={"return_token_ids": False}, + ) + + completion_without_dict = completion_without.model_dump() + assert completion_without_dict["choices"][0].get("token_ids") is None + assert completion_without_dict.get("prompt_token_ids") is None + + +@pytest.mark.asyncio +async def test_chat_completion_with_tool_use(server): + """Test chat completion with tool use (get_weather function).""" + tools = [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": + "string", + "description": + "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The unit of temperature", + }, + }, + "required": ["location"], + }, + }, + }] + + async with server.get_async_client() as client: + # Test with return_token_ids enabled + response = await client.chat.completions.create( + model=MODEL_NAME, + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What's the weather like in Paris?" 
+ }, + ], + tools=tools, + tool_choice="auto", + max_tokens=100, + temperature=0, + logprobs=True, + extra_body={"return_token_ids": True}, + ) + + # Verify token_ids field is present in choices + assert response.choices[0].token_ids is not None + assert isinstance(response.choices[0].token_ids, list) + + # Verify prompt_token_ids field is present + assert response.prompt_token_ids is not None + assert isinstance(response.prompt_token_ids, list) + + # Verify the prompt texts and response texts + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + prompt_text = tokenizer.decode(response.prompt_token_ids) + assert prompt_text.startswith( + "<|im_start|>system\nYou are a helpful assistant.") + assert prompt_text.endswith( + "What's the weather like in Paris?<|im_end|>\n" + "<|im_start|>assistant\n") + + response_text = tokenizer.decode(response.choices[0].token_ids) + assert response_text.startswith('<tool_call>\n{"name": "get_weather"') + assert response_text.endswith("</tool_call><|im_end|>") + + # If tool call was made, verify the response structure + if response.choices[0].message.tool_calls: + assert len(response.choices[0].message.tool_calls) > 0 + tool_call = response.choices[0].message.tool_calls[0] + assert tool_call.function.name == "get_weather" + + # Test without return_token_ids + response_without = await client.chat.completions.create( + model=MODEL_NAME, + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What's the weather like in Paris?" 
+ }, + ], + tools=tools, + tool_choice="auto", + max_tokens=100, + temperature=0, + logprobs=True, + extra_body={"return_token_ids": False}, + ) + + assert response_without.choices[0].token_ids is None + assert response_without.prompt_token_ids is None + + +@pytest.mark.asyncio +async def test_comparison_with_prompt_logprobs_and_logprobs(server): + """ + Test that token_ids align with prompt_logprobs and + logprobs when return_tokens_as_token_ids is enabled. + """ + async with server.get_async_client() as client: + # Test with both return_token_ids and return_tokens_as_token_ids enabled + completion = await client.completions.create( + model=MODEL_NAME, + prompt="Hello, world! How are you today?", + max_tokens=20, + temperature=0, + echo=True, + logprobs=1, + extra_body={ + "return_token_ids": True, + "return_tokens_as_token_ids": True, + "prompt_logprobs": 1 + }, + ) + + # Verify all fields are present + assert completion.choices[0].token_ids is not None + assert completion.choices[0].prompt_token_ids is not None + assert completion.choices[0].prompt_logprobs is not None + assert completion.choices[0].logprobs is not None + + # Extract token IDs from logprobs + # (when return_tokens_as_token_ids is True) + logprobs_token_ids = [] + for token_str in completion.choices[0].logprobs.tokens: + # Token format is "token_id:12345" when + # return_tokens_as_token_ids is True + if token_str.startswith("token_id:"): + token_id = int(token_str.removeprefix("token_id:")) + logprobs_token_ids.append(token_id) + + # When echo=True, the logprobs include both prompt and response tokens + # The token_ids field should match the the suffix of response portion + # The prompt_token_ids should match the prompt portion + assert len(completion.choices[0].token_ids) < len(logprobs_token_ids) + response_token_ids_length = len(completion.choices[0].token_ids) + assert logprobs_token_ids[-response_token_ids_length:] == \ + completion.choices[0].token_ids + + # Verify tokenizer consistency + 
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + + # Decode prompt tokens + if completion.choices[0].prompt_token_ids: + prompt_text = tokenizer.decode( + completion.choices[0].prompt_token_ids) + # The decoded prompt should match or close to original prompt + assert "Hello, world" in prompt_text + + # Decode response tokens + if completion.choices[0].token_ids: + response_text = tokenizer.decode(completion.choices[0].token_ids) + assert completion.choices[0].text.endswith(response_text) + + # Test streaming mode + stream = await client.completions.create( + model=MODEL_NAME, + prompt="Tell me a short fact about Python:", + max_tokens=30, + temperature=0, + stream=True, + echo=False, + logprobs=1, + extra_body={ + "return_token_ids": True, + "return_tokens_as_token_ids": True + }, + ) + + # Collect streamed tokens + streamed_prompt_token_ids = [] + streamed_token_ids = [] + streamed_logprob_token_ids = [] + first_chunk = True + async for chunk in stream: + for token_str in chunk.choices[0].logprobs.tokens: + # Token format is "token_id:12345" when + # return_tokens_as_token_ids is True + if token_str.startswith("token_id:"): + token_id = int(token_str.removeprefix("token_id:")) + streamed_logprob_token_ids.append(token_id) + if first_chunk: + streamed_prompt_token_ids = chunk.choices[0].prompt_token_ids + first_chunk = False + streamed_token_ids += chunk.choices[0].token_ids + + # Verify we collected some tokens and first chunk had prompt_token_ids + assert len(streamed_prompt_token_ids) > 0 + assert streamed_token_ids == streamed_logprob_token_ids + + +@pytest.mark.asyncio +async def test_chat_completion_with_emoji_and_token_ids(server): + """Test chat completion with emojis to verify token_ids handling.""" + chat_messages = [ + { + "role": "system", + "content": "You like to use emojis in your responses." 
+ }, + { + "role": "user", + "content": "Repeat after me: I love cats 🐱" + }, + ] + async with server.get_async_client() as client: + response = await client.chat.completions.create( + model=MODEL_NAME, + messages=chat_messages, + max_tokens=50, + temperature=0, + logprobs=True, + extra_body={"return_token_ids": True}, + ) + + # Verify token_ids are present + response_dict = response.model_dump() + assert response.choices[0].token_ids is not None + assert "prompt_token_ids" in response_dict + + # Verify the response contains the expected fields + assert response.choices[0].message.content is not None + + # Decode token_ids and verify consistency + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + + decoded_prompt = tokenizer.decode(response.prompt_token_ids) + assert decoded_prompt.startswith( + "<|im_start|>system\nYou like to use emojis in your responses.") + assert decoded_prompt.endswith( + "I love cats 🐱<|im_end|>\n<|im_start|>assistant\n") + + decoded_response = tokenizer.decode(response.choices[0].token_ids) + # The content should match the response text + # except the ending <|im_end|> + assert decoded_response == response.choices[ + 0].message.content + "<|im_end|>" + + # Test with streaming + stream = await client.chat.completions.create( + model=MODEL_NAME, + messages=chat_messages, + max_tokens=50, + temperature=0, + stream=True, + extra_body={"return_token_ids": True}, + ) + + collected_content = "" + collected_token_ids = [] + first_chunk = True + + async for chunk in stream: + if first_chunk: + assert chunk.prompt_token_ids is not None + assert isinstance(chunk.prompt_token_ids, list) + # Check the prompt_token_ids match the initial prompt + decoded_prompt_stream = tokenizer.decode( + chunk.prompt_token_ids) + assert decoded_prompt_stream == decoded_prompt + first_chunk = False + else: + chunk_dump = chunk.model_dump() + assert "prompt_token_ids" not in chunk_dump, \ + "Subsequent chunks should not have prompt_token_ids" + + if chunk.choices: + 
if chunk.choices[0].delta.content: + collected_content += chunk.choices[0].delta.content + # token_ids may not present in all chunks + choice_dump = chunk.choices[0].model_dump() + if "token_ids" in choice_dump: + collected_token_ids.extend(chunk.choices[0].token_ids) + + # Verify we got response and token_ids + assert len(collected_content) > 0 + assert len(collected_token_ids) > 0 + + # Verify token_ids decode properly + decoded_response = tokenizer.decode(collected_token_ids) + assert decoded_response == collected_content + "<|im_end|>" diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 8a7892cf6d6aa..10879f0be83c8 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -282,9 +282,11 @@ async def test_serving_chat_could_load_correct_generation_config(): assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 +@pytest.mark.parametrize("model_type", ["gpt_oss", "any"]) @pytest.mark.asyncio -async def test_serving_chat_did_set_correct_cache_salt(): +async def test_serving_chat_did_set_correct_cache_salt(model_type): mock_model_config = MockModelConfig() + mock_model_config.hf_config.model_type = model_type mock_engine = MagicMock(spec=MQLLMEngineClient) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index e103bd206b54c..6009d9aeec935 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -4,19 +4,20 @@ # imports for guided decoding tests import io import json -from unittest.mock import patch import librosa import numpy as np import openai import pytest +import pytest_asyncio import soundfile as sf -from openai._base_client import AsyncAPIClient from vllm.assets.audio import AudioAsset from ...utils 
import RemoteOpenAIServer +MODEL_NAME = "openai/whisper-large-v3-turbo" +SERVER_ARGS = ["--enforce-eager"] MISTRAL_FORMAT_ARGS = [ "--tokenizer_mode", "mistral", "--config_format", "mistral", "--load_format", "mistral" @@ -37,6 +38,18 @@ def winning_call(): yield f +@pytest.fixture(scope="module") +def server(): + with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", @@ -56,58 +69,18 @@ async def test_basic_audio(mary_had_lamb, model_name): language="en", response_format="text", temperature=0.0) - out = json.loads(transcription)['text'] - assert "Mary had a little lamb," in out - - -@pytest.mark.asyncio -async def test_bad_requests(mary_had_lamb): - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - - # invalid language - with pytest.raises(openai.BadRequestError): - await client.audio.transcriptions.create(model=model_name, - file=mary_had_lamb, - language="hh", - temperature=0.0) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3-turbo"]) -async def test_long_audio_request(mary_had_lamb, model_name): - server_args = ["--enforce-eager"] - - mary_had_lamb.seek(0) - audio, sr = librosa.load(mary_had_lamb) - # Add small silence after each audio for repeatability in the split process - audio = np.pad(audio, (0, 1600)) - repeated_audio = np.tile(audio, 10) - # Repeated audio to buffer - buffer = io.BytesIO() - sf.write(buffer, repeated_audio, sr, format='WAV') - buffer.seek(0) - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - transcription = await client.audio.transcriptions.create( - model=model_name, 
- file=buffer, - language="en", - response_format="text", - temperature=0.0) - out = json.loads(transcription)['text'] - counts = out.count("Mary had a little lamb") - assert counts == 10, counts + out = json.loads(transcription) + out_text = out['text'] + out_usage = out['usage'] + assert "Mary had a little lamb," in out_text + assert out_usage["seconds"] == 16, out_usage["seconds"] @pytest.mark.asyncio async def test_non_asr_model(winning_call): # text to text model model_name = "JackFram/llama-68m" - server_args = ["--enforce-eager"] - with RemoteOpenAIServer(model_name, server_args) as remote_server: + with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server: client = remote_server.get_async_client() res = await client.audio.transcriptions.create(model=model_name, file=winning_call, @@ -120,157 +93,152 @@ async def test_non_asr_model(winning_call): @pytest.mark.asyncio -async def test_completion_endpoints(): +async def test_bad_requests(mary_had_lamb, client): + # invalid language + with pytest.raises(openai.BadRequestError): + await client.audio.transcriptions.create(model=MODEL_NAME, + file=mary_had_lamb, + language="hh", + temperature=0.0) + + +@pytest.mark.asyncio +async def test_long_audio_request(mary_had_lamb, client): + mary_had_lamb.seek(0) + audio, sr = librosa.load(mary_had_lamb) + # Add small silence after each audio for repeatability in the split process + audio = np.pad(audio, (0, 1600)) + repeated_audio = np.tile(audio, 10) + # Repeated audio to buffer + buffer = io.BytesIO() + sf.write(buffer, repeated_audio, sr, format='WAV') + buffer.seek(0) + transcription = await client.audio.transcriptions.create( + model=MODEL_NAME, + file=buffer, + language="en", + response_format="text", + temperature=0.0) + out = json.loads(transcription) + out_text = out['text'] + out_usage = out['usage'] + counts = out_text.count("Mary had a little lamb") + assert counts == 10, counts + assert out_usage["seconds"] == 161, out_usage["seconds"] + + 
+@pytest.mark.asyncio +async def test_completion_endpoints(client): # text to text model - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - res = await client.chat.completions.create( - model=model_name, - messages=[{ - "role": "system", - "content": "You are a helpful assistant." - }]) - err = res.error - assert err["code"] == 400 - assert err[ - "message"] == "The model does not support Chat Completions API" + res = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "system", + "content": "You are a helpful assistant." + }]) + err = res.error + assert err["code"] == 400 + assert err["message"] == "The model does not support Chat Completions API" - res = await client.completions.create(model=model_name, prompt="Hello") - err = res.error - assert err["code"] == 400 - assert err["message"] == "The model does not support Completions API" + res = await client.completions.create(model=MODEL_NAME, prompt="Hello") + err = res.error + assert err["code"] == 400 + assert err["message"] == "The model does not support Completions API" @pytest.mark.asyncio -async def test_streaming_response(winning_call): - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] +async def test_streaming_response(winning_call, client): transcription = "" - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - res_no_stream = await client.audio.transcriptions.create( - model=model_name, - file=winning_call, - response_format="json", - language="en", - temperature=0.0) - # Unfortunately this only works when the openai client is patched - # to use streaming mode, not exposed in the transcription api. 
- original_post = AsyncAPIClient.post + res_no_stream = await client.audio.transcriptions.create( + model=MODEL_NAME, + file=winning_call, + response_format="json", + language="en", + temperature=0.0) + res = await client.audio.transcriptions.create(model=MODEL_NAME, + file=winning_call, + language="en", + temperature=0.0, + stream=True, + timeout=30) + # Reconstruct from chunks and validate + async for chunk in res: + text = chunk.choices[0]['delta']['content'] + transcription += text - async def post_with_stream(*args, **kwargs): - kwargs['stream'] = True - return await original_post(*args, **kwargs) - - with patch.object(AsyncAPIClient, "post", new=post_with_stream): - client = remote_server.get_async_client() - res = await client.audio.transcriptions.create( - model=model_name, - file=winning_call, - language="en", - temperature=0.0, - extra_body=dict(stream=True), - timeout=30) - # Reconstruct from chunks and validate - async for chunk in res: - # just a chunk - text = chunk.choices[0]['delta']['content'] - transcription += text - - assert transcription == res_no_stream.text + assert transcription == res_no_stream.text @pytest.mark.asyncio -async def test_stream_options(winning_call): - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] - with RemoteOpenAIServer(model_name, server_args) as remote_server: - original_post = AsyncAPIClient.post - - async def post_with_stream(*args, **kwargs): - kwargs['stream'] = True - return await original_post(*args, **kwargs) - - with patch.object(AsyncAPIClient, "post", new=post_with_stream): - client = remote_server.get_async_client() - res = await client.audio.transcriptions.create( - model=model_name, - file=winning_call, - language="en", - temperature=0.0, - extra_body=dict(stream=True, - stream_include_usage=True, - stream_continuous_usage_stats=True), - timeout=30) - final = False - continuous = True - async for chunk in res: - if not len(chunk.choices): - # final usage sent - final = True - else: 
- continuous = continuous and hasattr(chunk, 'usage') - assert final and continuous +async def test_stream_options(winning_call, client): + res = await client.audio.transcriptions.create( + model=MODEL_NAME, + file=winning_call, + language="en", + temperature=0.0, + stream=True, + extra_body=dict(stream_include_usage=True, + stream_continuous_usage_stats=True), + timeout=30) + final = False + continuous = True + async for chunk in res: + if not len(chunk.choices): + # final usage sent + final = True + else: + continuous = continuous and hasattr(chunk, 'usage') + assert final and continuous @pytest.mark.asyncio -async def test_sampling_params(mary_had_lamb): +async def test_sampling_params(mary_had_lamb, client): """ Compare sampling with params and greedy sampling to assert results are different when extreme sampling parameters values are picked. """ - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - transcription = await client.audio.transcriptions.create( - model=model_name, - file=mary_had_lamb, - language="en", - temperature=0.8, - extra_body=dict(seed=42, - repetition_penalty=1.9, - top_k=12, - top_p=0.4, - min_p=0.5, - frequency_penalty=1.8, - presence_penalty=2.0)) + transcription = await client.audio.transcriptions.create( + model=MODEL_NAME, + file=mary_had_lamb, + language="en", + temperature=0.8, + extra_body=dict(seed=42, + repetition_penalty=1.9, + top_k=12, + top_p=0.4, + min_p=0.5, + frequency_penalty=1.8, + presence_penalty=2.0)) - greedy_transcription = await client.audio.transcriptions.create( - model=model_name, - file=mary_had_lamb, - language="en", - temperature=0.0, - extra_body=dict(seed=42)) + greedy_transcription = await client.audio.transcriptions.create( + model=MODEL_NAME, + file=mary_had_lamb, + language="en", + temperature=0.0, + extra_body=dict(seed=42)) - assert greedy_transcription.text != 
transcription.text + assert greedy_transcription.text != transcription.text @pytest.mark.asyncio -async def test_audio_prompt(mary_had_lamb): - model_name = "openai/whisper-large-v3-turbo" - server_args = ["--enforce-eager"] +async def test_audio_prompt(mary_had_lamb, client): prompt = "This is a speech, recorded in a phonograph." - with RemoteOpenAIServer(model_name, server_args) as remote_server: - #Prompts should not omit the part of original prompt while transcribing. - prefix = "The first words I spoke in the original phonograph" - client = remote_server.get_async_client() - transcription = await client.audio.transcriptions.create( - model=model_name, - file=mary_had_lamb, - language="en", - response_format="text", - temperature=0.0) - out = json.loads(transcription)['text'] - assert prefix in out - transcription_wprompt = await client.audio.transcriptions.create( - model=model_name, - file=mary_had_lamb, - language="en", - response_format="text", - prompt=prompt, - temperature=0.0) - out_prompt = json.loads(transcription_wprompt)['text'] - assert prefix in out_prompt + #Prompts should not omit the part of original prompt while transcribing. 
+ prefix = "The first words I spoke in the original phonograph" + transcription = await client.audio.transcriptions.create( + model=MODEL_NAME, + file=mary_had_lamb, + language="en", + response_format="text", + temperature=0.0) + out = json.loads(transcription)['text'] + assert prefix in out + transcription_wprompt = await client.audio.transcriptions.create( + model=MODEL_NAME, + file=mary_had_lamb, + language="en", + response_format="text", + prompt=prompt, + temperature=0.0) + out_prompt = json.loads(transcription_wprompt)['text'] + assert prefix in out_prompt diff --git a/tests/entrypoints/openai/test_translation_validation.py b/tests/entrypoints/openai/test_translation_validation.py index bfa9bdef1c001..f4f5c66f2deeb 100644 --- a/tests/entrypoints/openai/test_translation_validation.py +++ b/tests/entrypoints/openai/test_translation_validation.py @@ -4,18 +4,21 @@ import io # imports for guided decoding tests import json -from unittest.mock import patch +import httpx import librosa import numpy as np import pytest +import pytest_asyncio import soundfile as sf -from openai._base_client import AsyncAPIClient from vllm.assets.audio import AudioAsset from ...utils import RemoteOpenAIServer +MODEL_NAME = "openai/whisper-small" +SERVER_ARGS = ["--enforce-eager"] + @pytest.fixture def foscolo(): @@ -25,50 +28,23 @@ def foscolo(): yield f -# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation! 
-@pytest.mark.asyncio -async def test_basic_audio(foscolo): - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - translation = await client.audio.translations.create( - model=model_name, - file=foscolo, - response_format="text", - # TODO remove once language detection is implemented - extra_body=dict(language="it"), - temperature=0.0) - out = json.loads(translation)['text'].strip().lower() - assert "greek sea" in out +@pytest.fixture(scope="module") +def server(): + with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server: + yield remote_server -@pytest.mark.asyncio -async def test_audio_prompt(foscolo): - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] - # Condition whisper on starting text - prompt = "Nor have I ever" - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - transcription = await client.audio.translations.create( - model=model_name, - file=foscolo, - prompt=prompt, - extra_body=dict(language="it"), - response_format="text", - temperature=0.0) - out = json.loads(transcription)['text'] - assert "Nor will I ever touch the sacred" not in out - assert prompt not in out +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client @pytest.mark.asyncio async def test_non_asr_model(foscolo): # text to text model model_name = "JackFram/llama-68m" - server_args = ["--enforce-eager"] - with RemoteOpenAIServer(model_name, server_args) as remote_server: + with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server: client = remote_server.get_async_client() res = await client.audio.translations.create(model=model_name, file=foscolo, @@ -78,81 +54,117 @@ async def test_non_asr_model(foscolo): assert err["message"] == "The model does not support Translations API" +# 
NOTE: (NickLucche) the large-v3-turbo model was not trained on translation! @pytest.mark.asyncio -async def test_streaming_response(foscolo): - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] +async def test_basic_audio(foscolo, client): + translation = await client.audio.translations.create( + model=MODEL_NAME, + file=foscolo, + response_format="text", + # TODO remove once language detection is implemented + extra_body=dict(language="it"), + temperature=0.0) + out = json.loads(translation)['text'].strip().lower() + assert "greek sea" in out + + +@pytest.mark.asyncio +async def test_audio_prompt(foscolo, client): + # Condition whisper on starting text + prompt = "Nor have I ever" + transcription = await client.audio.translations.create( + model=MODEL_NAME, + file=foscolo, + prompt=prompt, + extra_body=dict(language="it"), + response_format="text", + temperature=0.0) + out = json.loads(transcription)['text'] + assert "Nor will I ever touch the sacred" not in out + assert prompt not in out + + +@pytest.mark.asyncio +async def test_streaming_response(foscolo, client, server): translation = "" - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - res_no_stream = await client.audio.translations.create( - model=model_name, - file=foscolo, - response_format="json", - extra_body=dict(language="it"), - temperature=0.0) - # Unfortunately this only works when the openai client is patched - # to use streaming mode, not exposed in the translation api. 
- original_post = AsyncAPIClient.post + res_no_stream = await client.audio.translations.create( + model=MODEL_NAME, + file=foscolo, + response_format="json", + extra_body=dict(language="it"), + temperature=0.0) + # Stream via HTTPX since OpenAI translation client doesn't expose streaming + url = server.url_for("v1/audio/translations") + headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"} + data = { + "model": MODEL_NAME, + "language": "it", + "stream": True, + "temperature": 0.0, + } + foscolo.seek(0) + async with httpx.AsyncClient() as http_client: + files = {"file": foscolo} + async with http_client.stream("POST", + url, + headers=headers, + data=data, + files=files) as response: + async for line in response.aiter_lines(): + if not line: + continue + if line.startswith("data: "): + line = line[len("data: "):] + if line.strip() == "[DONE]": + break + chunk = json.loads(line) + text = chunk["choices"][0].get("delta", {}).get("content") + translation += text or "" - async def post_with_stream(*args, **kwargs): - kwargs['stream'] = True - return await original_post(*args, **kwargs) - - with patch.object(AsyncAPIClient, "post", new=post_with_stream): - client = remote_server.get_async_client() - res = await client.audio.translations.create(model=model_name, - file=foscolo, - temperature=0.0, - extra_body=dict( - stream=True, - language="it")) - # Reconstruct from chunks and validate - async for chunk in res: - # just a chunk - text = chunk.choices[0]['delta']['content'] - translation += text - - assert translation == res_no_stream.text + assert translation == res_no_stream.text @pytest.mark.asyncio -async def test_stream_options(foscolo): - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] - with RemoteOpenAIServer(model_name, server_args) as remote_server: - original_post = AsyncAPIClient.post - - async def post_with_stream(*args, **kwargs): - kwargs['stream'] = True - return await original_post(*args, **kwargs) - - with 
patch.object(AsyncAPIClient, "post", new=post_with_stream): - client = remote_server.get_async_client() - res = await client.audio.translations.create( - model=model_name, - file=foscolo, - temperature=0.0, - extra_body=dict(language="it", - stream=True, - stream_include_usage=True, - stream_continuous_usage_stats=True)) - final = False - continuous = True - async for chunk in res: - if not len(chunk.choices): +async def test_stream_options(foscolo, client, server): + url = server.url_for("v1/audio/translations") + headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"} + data = { + "model": MODEL_NAME, + "language": "it", + "stream": True, + "stream_include_usage": True, + "stream_continuous_usage_stats": True, + "temperature": 0.0, + } + foscolo.seek(0) + final = False + continuous = True + async with httpx.AsyncClient() as http_client: + files = {"file": foscolo} + async with http_client.stream("POST", + url, + headers=headers, + data=data, + files=files) as response: + async for line in response.aiter_lines(): + if not line: + continue + if line.startswith("data: "): + line = line[len("data: "):] + if line.strip() == "[DONE]": + break + chunk = json.loads(line) + choices = chunk.get("choices", []) + if not choices: # final usage sent final = True else: - continuous = continuous and hasattr(chunk, 'usage') - assert final and continuous + continuous = continuous and ("usage" in chunk) + assert final and continuous @pytest.mark.asyncio -async def test_long_audio_request(foscolo): - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] - +async def test_long_audio_request(foscolo, client): foscolo.seek(0) audio, sr = librosa.load(foscolo) repeated_audio = np.tile(audio, 2) @@ -160,13 +172,11 @@ async def test_long_audio_request(foscolo): buffer = io.BytesIO() sf.write(buffer, repeated_audio, sr, format='WAV') buffer.seek(0) - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - 
translation = await client.audio.translations.create( - model=model_name, - file=buffer, - extra_body=dict(language="it"), - response_format="text", - temperature=0.0) - out = json.loads(translation)['text'].strip().lower() - assert out.count("greek sea") == 2 + translation = await client.audio.translations.create( + model=MODEL_NAME, + file=buffer, + extra_body=dict(language="it"), + response_format="text", + temperature=0.0) + out = json.loads(translation)['text'].strip().lower() + assert out.count("greek sea") == 2 diff --git a/tests/entrypoints/openai/test_truncation.py b/tests/entrypoints/openai/test_truncation.py index 79b6ce059ce49..121c0413e1af7 100644 --- a/tests/entrypoints/openai/test_truncation.py +++ b/tests/entrypoints/openai/test_truncation.py @@ -64,6 +64,28 @@ async def test_smaller_truncation_size(client: openai.AsyncOpenAI): assert response["usage"]["prompt_tokens"] == truncation_size +@pytest.mark.asyncio +async def test_zero_truncation_size(client: openai.AsyncOpenAI): + truncation_size = 0 + kwargs: dict[str, Any] = { + "model": MODEL_NAME, + "input": input, + "truncate_prompt_tokens": truncation_size + } + + with pytest.raises(openai.BadRequestError) as err: + await client.post(path="embeddings", cast_to=object, body={**kwargs}) + + assert err.value.status_code == 400 + error_details = err.value.response.json()["error"] + + assert error_details["type"] == "BadRequestError" + assert "This model's maximum context length is" in error_details["message"] + assert "tokens in the input for embedding generation" in error_details[ + "message"] + assert "Please reduce the length of the input" in error_details["message"] + + @pytest.mark.asyncio async def test_bigger_truncation_size(client: openai.AsyncOpenAI): truncation_size = max_model_len + 1 @@ -74,18 +96,15 @@ async def test_bigger_truncation_size(client: openai.AsyncOpenAI): } with pytest.raises(openai.BadRequestError) as err: - err = await client.post(path="embeddings", - cast_to=object, - 
body={**kwargs}) + await client.post(path="embeddings", cast_to=object, body={**kwargs}) - assert str(err) == f"""openai.BadRequestError: - Error code: 400 - {{'object': 'error', - 'message': 'truncate_prompt_tokens value - ({truncation_size}) - is greater than max_model_len ({max_model_len}). - Please, select a smaller truncation size.', - 'type': 'BadRequestError', - 'param': None, 'code': 400}}""" + assert err.value.status_code == 400 + error_details = err.value.response.json()["error"] + assert error_details["type"] == "BadRequestError" + expected_message = ("truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size.") + assert error_details["message"] == expected_message @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 8259a81d7b6a1..106ec121a422e 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -6,8 +6,6 @@ import json import openai import pytest import pytest_asyncio -import requests -from PIL import Image from transformers import AutoProcessor from vllm.multimodal.utils import encode_image_base64, fetch_image @@ -88,7 +86,7 @@ def get_hf_prompt_tokens(model_name, content, image_url): "role": "user", "content": f"{placeholder}{content}", }] - images = [Image.open(requests.get(image_url, stream=True).raw)] + images = [fetch_image(image_url)] prompt = processor.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True) diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index 4e6a21058658b..d3cc2fac6af57 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -5,7 +5,6 @@ import json import pytest import requests -from PIL import Image from transformers import AutoProcessor from vllm.entrypoints.openai.protocol import 
EmbeddingResponse @@ -64,7 +63,7 @@ def get_hf_prompt_tokens(model_name, content, image_url): placeholder = "<|image_1|> " prompt = f"{placeholder}{content}" - images = [Image.open(requests.get(image_url, stream=True).raw)] + images = [fetch_image(image_url)] inputs = processor(prompt, images, return_tensors="pt") return inputs.input_ids.shape[1] diff --git a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py new file mode 100644 index 0000000000000..28b1f8358d80b --- /dev/null +++ b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json + +import pytest + +from ....utils import RemoteOpenAIServer + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +LORA_MODEL = "minpeter/LoRA-Llama-3.2-1B-tool-vllm-ci" + +SERVER_ARGS = [ + "--enforce-eager", + "--enable-auto-tool-choice", + "--tool-call-parser", + "hermes", + "--enable-lora", + "--lora-modules", + f"{LORA_MODEL}={LORA_MODEL}", +] + +TOOLS = [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": + "The city and state, e.g. 
San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + }, + }, + "required": ["location"], + }, + }, +}] + +MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}] + + +@pytest.mark.asyncio +async def test_non_streaming_tool_call(): + """Test tool call in non-streaming mode.""" + with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server: + client = server.get_async_client() + + response = await client.chat.completions.create( + model=LORA_MODEL, + messages=MESSAGES, + tools=TOOLS, + tool_choice="auto", + temperature=0.0, + ) + + assert response.choices + choice = response.choices[0] + message = choice.message + + assert choice.finish_reason == "tool_calls" + assert message.tool_calls is not None + + tool_call = message.tool_calls[0] + assert tool_call.type == "function" + assert tool_call.function.name == "get_current_weather" + + arguments = json.loads(tool_call.function.arguments) + assert "location" in arguments + assert "Boston" in arguments["location"] + print("\n[Non-Streaming Test Passed]") + print(f"Tool Call: {tool_call.function.name}") + print(f"Arguments: {arguments}") + + +@pytest.mark.asyncio +async def test_streaming_tool_call(): + """Test tool call in streaming mode.""" + with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server: + client = server.get_async_client() + + stream = await client.chat.completions.create( + model=LORA_MODEL, + messages=MESSAGES, + tools=TOOLS, + tool_choice="auto", + temperature=0.0, + stream=True, + ) + + tool_call_chunks = {} + async for chunk in stream: + if not chunk.choices: + continue + + delta = chunk.choices[0].delta + if not delta or not delta.tool_calls: + continue + + for tool_chunk in delta.tool_calls: + index = tool_chunk.index + if index not in tool_call_chunks: + tool_call_chunks[index] = {"name": "", "arguments": ""} + + if tool_chunk.function.name: + tool_call_chunks[index]["name"] += tool_chunk.function.name + if 
tool_chunk.function.arguments: + tool_call_chunks[index][ + "arguments"] += tool_chunk.function.arguments + + assert len(tool_call_chunks) == 1 + reconstructed_tool_call = tool_call_chunks[0] + + assert reconstructed_tool_call["name"] == "get_current_weather" + + arguments = json.loads(reconstructed_tool_call["arguments"]) + assert "location" in arguments + assert "Boston" in arguments["location"] + print("\n[Streaming Test Passed]") + print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}") + print(f"Reconstructed Arguments: {arguments}") diff --git a/tests/evals/gsm8k/README.md b/tests/evals/gsm8k/README.md new file mode 100644 index 0000000000000..58572c3a6fbc1 --- /dev/null +++ b/tests/evals/gsm8k/README.md @@ -0,0 +1,35 @@ +# GSM8K Accuracy Evaluation + +This directory contains a replacement for the lm-eval-harness GSM8K evaluation, using an isolated GSM8K script and vLLM server for better performance and control. + +## Usage + +### Run tests with pytest (like buildkite) + +```bash +pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \ + --config-list-file=configs/models-small.txt \ + --tp-size=1 +``` + +### Run standalone evaluation script + +```bash +# Start vLLM server first +vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000 + +# Run evaluation +python tests/evals/gsm8k/gsm8k_eval.py --port 8000 +``` + +## Configuration Format + +Model configs in the `configs/` directory use this YAML format: + +```yaml +model_name: "Qwen/Qwen2.5-1.5B-Instruct" +accuracy_threshold: 0.54 # Minimum expected accuracy +num_questions: 1319 # Number of questions (default: full test set) +num_fewshot: 5 # Few-shot examples from train set +max_model_len: 4096 # Model context length +``` diff --git a/tests/evals/gsm8k/__init__.py b/tests/evals/gsm8k/__init__.py new file mode 100644 index 0000000000000..0fec1fe5bcdfd --- /dev/null +++ b/tests/evals/gsm8k/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM 
project \ No newline at end of file diff --git a/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml b/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml new file mode 100644 index 0000000000000..caa0448f23d48 --- /dev/null +++ b/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml @@ -0,0 +1,5 @@ +model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test" +accuracy_threshold: 0.74 +num_questions: 1319 +num_fewshot: 5 +max_model_len: 4096 \ No newline at end of file diff --git a/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml b/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml new file mode 100644 index 0000000000000..615aa69a2d2b6 --- /dev/null +++ b/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml @@ -0,0 +1,5 @@ +model_name: "RedHatAI/Llama-3.2-1B-Instruct-quantized.w8a8" +accuracy_threshold: 0.31 +num_questions: 1319 +num_fewshot: 5 +max_model_len: 4096 \ No newline at end of file diff --git a/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml new file mode 100644 index 0000000000000..c5dbceeeb2b45 --- /dev/null +++ b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml @@ -0,0 +1,5 @@ +model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16" +accuracy_threshold: 0.45 +num_questions: 1319 +num_fewshot: 5 +max_model_len: 4096 \ No newline at end of file diff --git a/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml b/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml new file mode 100644 index 0000000000000..5319ada30f645 --- /dev/null +++ b/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml @@ -0,0 +1,5 @@ +model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic" +accuracy_threshold: 0.60 +num_questions: 1319 +num_fewshot: 5 +max_model_len: 4096 \ No newline at end of file diff --git a/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml 
b/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml new file mode 100644 index 0000000000000..c39fb979d98ac --- /dev/null +++ b/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml @@ -0,0 +1,5 @@ +model_name: "Qwen/Qwen3-0.6B-FP8" +accuracy_threshold: 0.375 +num_questions: 1319 +num_fewshot: 5 +max_model_len: 4096 \ No newline at end of file diff --git a/tests/evals/gsm8k/configs/models-small.txt b/tests/evals/gsm8k/configs/models-small.txt new file mode 100644 index 0000000000000..afd1065b9191b --- /dev/null +++ b/tests/evals/gsm8k/configs/models-small.txt @@ -0,0 +1,5 @@ +Qwen3-0.6B-FP8.yaml +Llama-3.2-1B-Instruct-INT8-CT.yaml +Llama-3-8B-Instruct-nonuniform-CT.yaml +Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml +Qwen1.5-MoE-W4A16-CT.yaml diff --git a/tests/evals/gsm8k/conftest.py b/tests/evals/gsm8k/conftest.py new file mode 100644 index 0000000000000..d96b0a66ede2b --- /dev/null +++ b/tests/evals/gsm8k/conftest.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from pathlib import Path + + +def pytest_addoption(parser): + """Add custom command line options.""" + parser.addoption("--config-list-file", + default="configs/models-small.txt", + help="File containing list of config files to test") + parser.addoption("--tp-size", + default=1, + type=int, + help="Tensor parallel size") + + +def pytest_generate_tests(metafunc): + """Generate test parameters from config files.""" + if "config_filename" in metafunc.fixturenames: + config_list_file = metafunc.config.getoption("--config-list-file") + tp_size = metafunc.config.getoption("--tp-size") + + # Handle both relative and absolute paths + config_list_path = Path(config_list_file) + if not config_list_path.is_absolute(): + # If relative, try relative to test directory first + test_dir_path = Path(__file__).parent / config_list_file + if test_dir_path.exists(): + config_list_path = test_dir_path + else: + # Try relative to current working directory + 
config_list_path = Path.cwd() / config_list_file + + print(f"Looking for config list at: {config_list_path}") + + config_files = [] + if config_list_path.exists(): + # Determine config directory (same directory as the list file) + config_dir = config_list_path.parent + + with open(config_list_path) as f: + for line in f: + line = line.strip() + if line and not line.startswith("#"): + config_path = config_dir / line + print(f"Checking config file: {config_path}") + if config_path.exists(): + config_files.append(config_path) + print(f" ✓ Found: {config_path}") + else: + print(f" ✗ Missing: {config_path}") + else: + print(f"Config list file not found: {config_list_path}") + + # Generate test parameters + if config_files: + metafunc.parametrize(["config_filename", "tp_size"], + [(config_file, int(tp_size)) + for config_file in config_files], + ids=[ + f"{config_file.stem}-tp{tp_size}" + for config_file in config_files + ]) + else: + print("No config files found, test will be skipped") diff --git a/tests/evals/gsm8k/gsm8k_eval.py b/tests/evals/gsm8k/gsm8k_eval.py new file mode 100644 index 0000000000000..7d0ce25f75dd4 --- /dev/null +++ b/tests/evals/gsm8k/gsm8k_eval.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Isolated GSM8K evaluation script for vLLM serve endpoint. 
+""" + +import argparse +import ast +import asyncio +import json +import os +import time +from collections.abc import Generator +from typing import Optional, Union + +import aiohttp +import numpy as np +import regex as re +import requests +from tqdm.asyncio import tqdm + +INVALID = -9999999 + + +def download_and_cache_file(url: str, filename: Optional[str] = None) -> str: + """Download and cache a file from a URL.""" + if filename is None: + filename = os.path.join("/tmp", url.split("/")[-1]) + + if os.path.exists(filename): + return filename + + print(f"Downloading from {url} to {filename}") + response = requests.get(url, stream=True) + response.raise_for_status() + + with open(filename, "wb") as f: + for chunk in response.iter_content(chunk_size=1024): + f.write(chunk) + + return filename + + +def load_gsm8k_data() -> tuple[list[dict], list[dict]]: + """Load GSM8K train and test data""" + train_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl" + test_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + + train_file = download_and_cache_file(train_url) + test_file = download_and_cache_file(test_url) + + train_data = list(read_jsonl(train_file)) + test_data = list(read_jsonl(test_file)) + + return train_data, test_data + + +def read_jsonl(filename: str) -> Generator[dict, None, None]: + """Read a JSONL file.""" + with open(filename) as fin: + for line in fin: + if not line.startswith("#"): + yield json.loads(line) + + +def get_answer_value(answer_str: str) -> int: + """Extract the numerical answer from the response.""" + answer_str = answer_str.replace(",", "") + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +async def call_vllm_api(session: aiohttp.ClientSession, + prompt: str, + temperature: float, + max_tokens: int, + stop: 
Optional[list[str]] = None, + url: Optional[str] = None, + seed: Optional[int] = None) -> str: + """Call vLLM's OpenAI-compatible completions endpoint.""" + data = { + "prompt": prompt, + "temperature": temperature, + "max_tokens": max_tokens, + "stop": stop, + } + if seed is not None: + data["seed"] = seed + + try: + async with session.post(f"{url}/v1/completions", + json=data) as response: + response.raise_for_status() + result = await response.json() + return result["choices"][0]["text"] + except Exception as e: + print(f"Error calling vLLM API: {e}") + return "" + + +def evaluate_gsm8k(num_questions: int = 1319, + num_shots: int = 5, + max_tokens: int = 256, + host: str = "http://127.0.0.1", + port: int = 8000, + temperature: float = 0.0, + seed: Optional[int] = 42) -> dict[str, Union[float, int]]: + """ + Evaluate GSM8K accuracy using vLLM serve endpoint. + + Returns dict with accuracy, invalid_rate, latency, etc. + """ + base_url = f"{host}:{port}" + + # Load GSM8K train and test data + train_data, test_data = load_gsm8k_data() + + # Limit to available test questions + num_questions = min(num_questions, len(test_data)) + + # Build few-shot examples from train split (like lm-eval does) + few_shot_examples = "" + for i in range(num_shots): + few_shot_examples += (f"Question: {train_data[i]['question']}\n" + f"Answer: {train_data[i]['answer']}\n\n") + + # Prepare test questions and labels from test split + questions = [] + labels = [] + for i in range(num_questions): + questions.append(f"Question: {test_data[i]['question']}\nAnswer:") + labels.append(get_answer_value(test_data[i]["answer"])) + + assert all(label != INVALID for label in labels), "Some labels are invalid" + + # Run evaluation + async def run_async_evaluation(): + states: list[str] = [""] * num_questions + + async def get_answer(session: aiohttp.ClientSession, i: int) -> str: + prompt = few_shot_examples + questions[i] + answer = await call_vllm_api( + session=session, + prompt=prompt, + 
temperature=temperature, + max_tokens=max_tokens, + stop=["Question", "Assistant:", "<|separator|>"], + url=base_url, + seed=seed, + ) + states[i] = answer + return answer + + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout( + total=600)) as session: + tasks = [get_answer(session, i) for i in range(num_questions)] + await tqdm.gather(*tasks, desc="Evaluating") + + return states + + print(f"Running GSM8K evaluation: {num_questions} questions, " + f"{num_shots}-shot") + + tic = time.perf_counter() + states = asyncio.run(run_async_evaluation()) + latency = time.perf_counter() - tic + + # Compute metrics + preds = [get_answer_value(state) for state in states] + accuracy = np.mean(np.array(preds) == np.array(labels)) + invalid_rate = np.mean(np.array(preds) == INVALID) + + result = { + "accuracy": accuracy, + "invalid_rate": invalid_rate, + "latency": latency, + "questions_per_second": num_questions / latency, + "num_questions": num_questions, + "num_shots": num_shots, + "max_tokens": max_tokens, + "timestamp": time.time(), + } + + return result + + +def main() -> None: + parser = argparse.ArgumentParser( + description="GSM8K evaluation for vLLM serve") + parser.add_argument("--num-shots", + type=int, + default=5, + help="Number of few-shot examples") + parser.add_argument("--num-questions", + type=int, + default=1319, + help="Number of questions to evaluate") + parser.add_argument("--max-tokens", + type=int, + default=256, + help="Max tokens for generation") + parser.add_argument("--host", + type=str, + default="http://127.0.0.1", + help="Host URL") + parser.add_argument("--port", type=int, default=8000, help="Port number") + parser.add_argument("--temperature", + type=float, + default=0.0, + help="Temperature for generation") + parser.add_argument("--seed", + type=int, + default=42, + help="Random seed for reproducibility") + parser.add_argument("--save-results", + type=str, + help="Save results to JSON file") + + args = parser.parse_args() + + result = 
evaluate_gsm8k( + num_questions=args.num_questions, + num_shots=args.num_shots, + max_tokens=args.max_tokens, + host=args.host, + port=args.port, + temperature=args.temperature, + seed=args.seed, + ) + + # Print results to terminal + print("\nResults:") + print(f"Accuracy: {result['accuracy']:.3f}") + print(f"Invalid responses: {result['invalid_rate']:.3f}") + print(f"Total latency: {result['latency']:.3f} s") + print(f"Questions per second: {result['questions_per_second']:.3f}") + + # Optional file saving + if args.save_results: + with open(args.save_results, "w") as f: + json.dump(result, f, indent=2) + print(f"Results saved to {args.save_results}") + + +if __name__ == "__main__": + main() diff --git a/tests/evals/gsm8k/test_gsm8k_correctness.py b/tests/evals/gsm8k/test_gsm8k_correctness.py new file mode 100644 index 0000000000000..a12dd49dbea6d --- /dev/null +++ b/tests/evals/gsm8k/test_gsm8k_correctness.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +GSM8K evaluation using vLLM server and isolated GSM8K script. +Replacement for lm-eval-harness with better performance and control. 
+ +Usage: +pytest -s -v test_gsm8k_correctness.py \ + --config-list-file=configs/models-small.txt \ + --tp-size=1 +""" + +import yaml + +from tests.utils import RemoteOpenAIServer + +from .gsm8k_eval import evaluate_gsm8k + +RTOL = 0.08 # Absolute margin allowed below the expected accuracy (not relative, despite the name) + + +def launch_gsm8k_eval(eval_config, server_url, tp_size): + """Launch GSM8K evaluation using our isolated script.""" + # Extract host and port from server URL + if "://" in server_url: + server_url = server_url.split("://")[1] + + host_port = server_url.split("/")[0] # Remove path if present + if ":" in host_port: + host, port = host_port.split(":") + port = int(port) + else: + host = host_port + port = 8000 + + # Add http:// prefix if not present + if not host.startswith("http"): + host = f"http://{host}" + + # Run GSM8K evaluation + results = evaluate_gsm8k( + num_questions=eval_config["num_questions"], + num_shots=eval_config["num_fewshot"], + host=host, + port=port, + ) + + return results + + +def test_gsm8k_correctness_param(config_filename, tp_size): + """Test GSM8K correctness for a given model configuration.""" + eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8")) + + # Server arguments + server_args = [ + "--max-model-len", + str(eval_config.get("max_model_len", 4096)), + "--enforce-eager", + "--trust-remote-code", + "--tensor-parallel-size", + str(tp_size), + ] + + # Launch server and run evaluation + with RemoteOpenAIServer(eval_config["model_name"], + server_args, + max_wait_seconds=480) as remote_server: + server_url = remote_server.url_for("v1") + + results = launch_gsm8k_eval(eval_config, server_url, tp_size) + + # Check accuracy against threshold + measured_accuracy = results["accuracy"] + expected_accuracy = eval_config["accuracy_threshold"] + + print(f"GSM8K Results for {eval_config['model_name']}:") + print(f" Accuracy: {measured_accuracy:.3f}") + print(f" Expected: {expected_accuracy:.3f}") + print(f" Questions: {results['num_questions']}") + 
print(f" Invalid rate: {results['invalid_rate']:.3f}") + print(f" Latency: {results['latency']:.1f}s") + print(f" QPS: {results['questions_per_second']:.1f}") + + # Verify accuracy is within tolerance + assert measured_accuracy >= expected_accuracy - RTOL, ( + f"Accuracy too low: {measured_accuracy:.3f} < " + f"{expected_accuracy:.3f} - {RTOL:.3f}") + + print(f"✅ GSM8K test passed for {eval_config['model_name']}") diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index bfeafaa9e27e6..aea166da3af2f 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -81,6 +81,9 @@ def test_env( m.setenv(STR_BACKEND_ENV_VAR, name) m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") + if name == "FLASHINFER" and not use_v1: + pytest.skip("FlashInfer backend is only available on V1 engine") + if device == "cpu": if not use_v1: pytest.skip("CPU backend only supports V1") diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index 8c3cc8cba9d9f..cbf11da63cab9 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -709,14 +709,15 @@ def test_swap_blocks_mla( @pytest.mark.parametrize("max_seq_len", [512]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("dtype", [torch.float32]) -@pytest.mark.parametrize("kv_cache_dtype", - ["auto"]) # You can also test "fp8" if needed. 
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, - num_blocks, max_seq_len, batch_size, dtype, - kv_cache_dtype, device): +def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim, + block_size, num_blocks, + max_seq_len, batch_size, dtype, + kv_cache_dtype, device): entry_size = kv_lora_rank + qk_rope_head_dim + scale = torch.tensor(0.1, dtype=torch.float32, device=device) src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device) _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype) @@ -742,9 +743,7 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, perm = torch.randperm(num_blocks, device=device) block_table[b, :] = perm - dst = torch.zeros((total_tokens, entry_size), - dtype=src_cache.dtype, - device=device) + dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device) expected_batches = [] for b in range(batch_size): @@ -756,21 +755,38 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, gathered_rows = [] for i in range(tot - 1): - gathered_rows.append(src_cache[blocks[i]]) + block_data = src_cache[blocks[i]] + if kv_cache_dtype == "fp8": + dequantized_block = torch.empty_like(block_data, dtype=dtype) + ops.convert_fp8(dequantized_block, block_data, scale.item()) + gathered_rows.append(dequantized_block) + else: + gathered_rows.append(block_data) remaining = s - (tot - 1) * block_size - gathered_rows.append(src_cache[blocks[-1], :remaining, :]) + last_block_data = src_cache[blocks[-1], :remaining, :] + if kv_cache_dtype == "fp8": + dequantized_last_block = torch.empty_like(last_block_data, + dtype=dtype) + ops.convert_fp8(dequantized_last_block, last_block_data, + scale.item()) + gathered_rows.append(dequantized_last_block) + else: + gathered_rows.append(last_block_data) batch_expected = 
torch.cat(gathered_rows, dim=0) expected_batches.append(batch_expected) expected = torch.cat(expected_batches, dim=0) opcheck( - torch.ops._C_cache_ops.gather_cache, - (src_cache, dst, block_table, cu_seq_lens, batch_size, None), + torch.ops._C_cache_ops.gather_and_maybe_dequant_cache, + (src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype, + scale, None), test_utils=DEFAULT_OPCHECK_TEST_UTILS, ) - ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size) + ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table, + cu_seq_lens, batch_size, kv_cache_dtype, + scale, None) torch.testing.assert_close(dst, expected) diff --git a/tests/kernels/attention/test_flashinfer.py b/tests/kernels/attention/test_flashinfer.py index be78f0e4fcc62..a821a74aba93d 100644 --- a/tests/kernels/attention/test_flashinfer.py +++ b/tests/kernels/attention/test_flashinfer.py @@ -137,9 +137,7 @@ def test_flashinfer_decode_with_paged_kv( workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) wrapper = flashinfer.\ BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", - use_tensor_cores=( - (num_query_heads//num_kv_heads) > 4) - ) + use_tensor_cores=True) wrapper.plan( kv_indptr, kv_indices, @@ -411,7 +409,7 @@ def test_flashinfer_decode_with_paged_fp8_kv( assert num_query_heads % num_kv_heads == 0 max_kv_len = max(kv_lens) scale = head_size**-0.5 - use_tensor_cores = (num_query_heads // num_kv_heads) > 4 + use_tensor_cores = True kv_cache_dtype = torch.float8_e4m3fn query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py index 53e225ea3ea6c..8d0a11d8eb8ab 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -6,28 +6,19 @@ import flashinfer import pytest import torch +from 
tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) from vllm.platforms import current_platform +from vllm.utils import round_up if not current_platform.is_device_capability(100): pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True) FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 - -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) - -MAX_Q_LEN = 1024 -MAX_KV_LEN = 4096 -BATCH_SIZES = [4, 12] -NUM_HEADS = [(16, 16), (40, 8)] -HEAD_SIZES = [128] -BLOCK_SIZES = [16] -KV_LAYOUTS = ["HND"] -DTYPES = [torch.bfloat16] -KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()] -NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. -SOFT_CAPS = [None, 50.0] +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 def to_float8(x, dtype=torch.float8_e4m3fn): @@ -39,42 +30,61 @@ def to_float8(x, dtype=torch.float8_e4m3fn): return x_scl_sat.to(dtype), scale.float().reciprocal() -@pytest.mark.parametrize("batch_size", BATCH_SIZES) +DTYPE = [torch.bfloat16] +QUANT_DTYPES = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (None, FP8_DTYPE, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), +] +BATCH_SIZE = [4, 12] +MAX_SEQ_LENS = [(1024, 4096)] +NUM_HEADS = [(64, 8), (40, 8)] +HEAD_SIZE = [128] +KV_LAYOUT = ["HND"] # currently only HND is supported +BLOCK_SIZE = [16] +SOFT_CAP = [None, 50.0] + +NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. 
+ + +@pytest.mark.parametrize("dtype", DTYPE) +@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES) +@pytest.mark.parametrize("batch_size", BATCH_SIZE) +@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("kv_layout", KV_LAYOUTS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) -@pytest.mark.parametrize("soft_cap", SOFT_CAPS) +@pytest.mark.parametrize("head_size", HEAD_SIZE) +@pytest.mark.parametrize("kv_layout", KV_LAYOUT) +@pytest.mark.parametrize("block_size", BLOCK_SIZE) +@pytest.mark.parametrize("soft_cap", SOFT_CAP) @torch.inference_mode def test_flashinfer_trtllm_decode_with_baseline( + dtype: torch.dtype, + quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype], + Optional[torch.dtype]], batch_size: int, + max_seq_lens: tuple[int, int], num_heads: tuple[int, int], head_size: int, - block_size: int, kv_layout: str, - dtype: torch.dtype, - kv_cache_dtype: Optional[torch.dtype], + block_size: int, soft_cap: Optional[float], ) -> None: - kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype - torch.set_default_device("cuda") current_platform.seed_everything(0) - kv_lens = torch.randint(1, MAX_KV_LEN, (batch_size, ), dtype=torch.int32) - kv_lens[-1] = MAX_KV_LEN - max_kv_len = torch.max(kv_lens).item() - num_seqs = len(kv_lens) + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 + _, max_kv_len = max_seq_lens - scale = head_size**-0.5 + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 - query = torch.randn(num_seqs, num_query_heads, 
head_size, dtype=dtype) + sm_scale = float(1.0 / (head_size**0.5)) kv_cache_shape = None if kv_layout == "NHD": @@ -83,156 +93,39 @@ def test_flashinfer_trtllm_decode_with_baseline( kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) else: raise ValueError(f"Invalid kv_layout: {kv_layout}") - key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) - kv_scale = 1.0 - if kv_cache_dtype is current_platform.fp8_dtype(): - key_value_cache, kv_scale = to_float8(key_value_cache, - current_platform.fp8_dtype()) - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) - k_scale = v_scale = kv_scale - kv_indptr = [0] - kv_indices = [] - kv_last_page_lens = [] - for i in range(num_seqs): - seq_len = kv_lens[i] - assert seq_len > 0 - num_blocks = (seq_len + block_size - 1) // block_size - kv_indices.extend(block_tables[i, :num_blocks]) - kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % block_size - if kv_last_page_len == 0: - kv_last_page_len = block_size - kv_last_page_lens.append(kv_last_page_len) + query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, q_scale = to_float8(query) + ref_query = query.to(dtype) * q_scale + else: + q_scale = 1.0 + ref_query = query - kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) - kv_indices = torch.tensor(kv_indices, dtype=torch.int32) - kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32) + kv_lens[-1] = max_kv_len - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout, - use_tensor_cores=((num_query_heads // num_kv_heads) > 4)) - wrapper.plan(kv_indptr, - kv_indices, - kv_last_page_lens, - num_query_heads, - 
num_kv_heads, - head_size, - block_size, - "NONE", - sm_scale=scale, - q_data_type=dtype, - kv_data_type=kv_cache_dtype, - logits_soft_cap=soft_cap) - - output = torch.empty(query.shape, dtype=dtype) - wrapper.run(query, - key_value_cache, - k_scale=k_scale, - v_scale=v_scale, - out=output) - - # TRTLLM Decode - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) - output_trtllm = torch.empty(query.shape, dtype=dtype) - flashinfer.decode.trtllm_batch_decode_with_kv_cache( - query=query.contiguous(), - kv_cache=key_value_cache, - workspace_buffer=workspace_buffer, - block_tables=block_tables, - seq_lens=kv_lens_tensor, - max_seq_len=max_kv_len, - bmm1_scale=k_scale * scale, - bmm2_scale=v_scale, - out=output_trtllm, - ) - - torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - output_trtllm))}" - - -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("kv_layout", KV_LAYOUTS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) -@pytest.mark.parametrize("soft_cap", [None]) -@torch.inference_mode -def test_flashinfer_trtllm_prefill_with_baseline( - batch_size: int, - num_heads: tuple[int, int], - head_size: int, - block_size: int, - kv_layout: str, - dtype: torch.dtype, - kv_cache_dtype: Optional[torch.dtype], - soft_cap: Optional[float], -) -> None: - kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype - if dtype != kv_cache_dtype: - pytest.skip(f"Not supported dtype({dtype}) with " - "kv_cache_dtype({kv_cache_dtype})") - - torch.set_default_device("cuda") - current_platform.seed_everything(0) - - q_lens = torch.randint(1, MAX_Q_LEN, (batch_size, ), dtype=torch.int32) - q_lens[-1] = MAX_Q_LEN - max_q_len = torch.max(q_lens).item() - q_indptr = torch.cat([ - 
torch.tensor([0], dtype=torch.int32), - torch.cumsum(q_lens, dim=0, dtype=torch.int32), - ]) - - kv_lens = torch.randint(0, MAX_KV_LEN, (batch_size, ), dtype=torch.int32) - kv_lens[-1] = MAX_KV_LEN - - seq_lens = kv_lens + q_lens + seq_lens = kv_lens max_seq_len = torch.max(seq_lens).item() - num_seqs = len(seq_lens) - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - - scale = head_size**-0.5 - - query = torch.randn(torch.sum(q_lens).item(), - num_query_heads, - head_size, - dtype=dtype) - - kv_cache_shape = None - if kv_layout == "NHD": - kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) - elif kv_layout == "HND": - kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, kv_scale = to_float8(kv_cache) + ref_kv_cache = kv_cache.to(dtype) * kv_scale else: - raise ValueError(f"Invalid kv_layout: {kv_layout}") - key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) - kv_scale = 1.0 - if kv_cache_dtype is current_platform.fp8_dtype(): - key_value_cache, kv_scale = to_float8(key_value_cache, - current_platform.fp8_dtype()) + kv_scale = 1.0 + ref_kv_cache = kv_cache + k_scale = v_scale = kv_scale max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint(0, NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), + (batch_size, max_num_blocks_per_seq), dtype=torch.int32) - k_scale = v_scale = kv_scale kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] - for i in range(num_seqs): + for i in range(batch_size): seq_len = seq_lens[i] assert seq_len > 0 num_blocks = (seq_len + block_size - 1) // block_size @@ -246,48 +139,259 @@ def test_flashinfer_trtllm_prefill_with_baseline( kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, 
dtype=torch.int32) + workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + # Baseline Decode + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, kv_layout, use_tensor_cores=True) + wrapper.plan(kv_indptr, + kv_indices, + kv_last_page_lens, + num_qo_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + sm_scale=sm_scale, + q_data_type=dtype, + kv_data_type=dtype, + logits_soft_cap=soft_cap) + + output = torch.empty(ref_query.shape, dtype=dtype) + wrapper.run(ref_query, ref_kv_cache, out=output) + o_scale = 1.0 + o_sf_scale = None + if o_quant_dtype == FP8_DTYPE: + _, o_scale = to_float8(output) + elif o_quant_dtype == FP4_DTYPE: + o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(output.flatten(), dim=-1)).to(torch.float32) + + # TRTLLM Decode + if o_quant_dtype == FP4_DTYPE: + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ), + dtype=torch.uint8), + torch.empty((round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4)), + dtype=torch.float8_e4m3fn), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + + flashinfer.decode.trtllm_batch_decode_with_kv_cache( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + o_sf_scale=o_sf_scale, + out=output_trtllm, + ) + if o_quant_dtype == FP8_DTYPE: + output_trtllm = output_trtllm.to(dtype) * o_scale + elif o_quant_dtype == FP4_DTYPE: + output_trtllm.data = output_trtllm.data.reshape( + -1, query.shape[1] * query.shape[2] // 2) + output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data, + output_trtllm.scale, + o_sf_scale, dtype, + query.device) + output_trtllm = output_trtllm.reshape(-1, query.shape[1], + 
query.shape[2]) + + if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: + rtol, atol = 3e-1, 1e0 + elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: + rtol, atol = 5e-2, 7e-2 + else: + rtol, atol = 1e-2, 2e-2 + + torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \ + f"{torch.max(torch.abs(output - output_trtllm))}" + + +@pytest.mark.parametrize("dtype", DTYPE) +@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES) +@pytest.mark.parametrize("batch_size", BATCH_SIZE) +@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZE) +@pytest.mark.parametrize("kv_layout", KV_LAYOUT) +@pytest.mark.parametrize("block_size", BLOCK_SIZE) +@pytest.mark.parametrize("soft_cap", [None]) +@torch.inference_mode +def test_flashinfer_trtllm_prefill_with_baseline( + dtype: torch.dtype, + quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype], + Optional[torch.dtype]], + batch_size: int, + max_seq_lens: tuple[int, int], + num_heads: tuple[int, int], + head_size: int, + kv_layout: str, + block_size: int, + soft_cap: Optional[float], +) -> None: + torch.set_default_device("cuda") + current_platform.seed_everything(0) + + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + if q_quant_dtype != kv_quant_dtype: + pytest.skip("Skipped mixed QKV dtypes for prefill") + + max_q_len, max_kv_len = max_seq_lens + + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 + + sm_scale = float(1.0 / (head_size**0.5)) + + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: 
{kv_layout}") + + q_lens = torch.randint(1, max_q_len, (batch_size, ), dtype=torch.int32) + q_lens[-1] = max_q_len + q_indptr = torch.cat([ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(q_lens, dim=0, dtype=torch.int32), + ]) + + query = torch.randn(torch.sum(q_lens).item(), + num_qo_heads, + head_size, + dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, q_scale = to_float8(query) + ref_query = query.to(dtype) * q_scale + else: + q_scale = 1.0 + ref_query = query + + kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32) + kv_lens[-1] = max_kv_len + + seq_lens = kv_lens + q_lens + max_seq_len = torch.max(seq_lens).item() + + kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, kv_scale = to_float8(kv_cache) + ref_kv_cache = kv_cache.to(dtype) * kv_scale + else: + kv_scale = 1.0 + ref_kv_cache = kv_cache + k_scale = v_scale = kv_scale + + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS, + (batch_size, max_num_blocks_per_seq), + dtype=torch.int32) + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(batch_size): + seq_len = seq_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) + + # Baseline Prefill wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout) wrapper.plan(q_indptr, kv_indptr, kv_indices, kv_last_page_lens, - num_query_heads, + 
num_qo_heads, num_kv_heads, head_size, block_size, causal=True, - sm_scale=scale, + sm_scale=sm_scale, q_data_type=dtype, - kv_data_type=kv_cache_dtype, + kv_data_type=dtype, logits_soft_cap=soft_cap) - output = torch.empty(query.shape, dtype=dtype) - wrapper.run(query, - key_value_cache, - k_scale=k_scale, - v_scale=v_scale, - out=output) + output = torch.empty(ref_query.shape, dtype=dtype) + wrapper.run(ref_query, ref_kv_cache, out=output) + o_scale = 1.0 + o_sf_scale = None + if o_quant_dtype == FP8_DTYPE: + _, o_scale = to_float8(output) + elif o_quant_dtype == FP4_DTYPE: + o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(output.flatten(), dim=-1)).to(torch.float32) + + # TRTLLM Prefill + if o_quant_dtype == FP4_DTYPE: + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ), + dtype=torch.uint8), + torch.empty((round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4)), + dtype=torch.float8_e4m3fn), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) - # TRTLLM Decode - output_trtllm = torch.empty(query.shape, dtype=dtype) flashinfer.prefill.trtllm_batch_context_with_kv_cache( - query=query.contiguous(), - kv_cache=key_value_cache, + query=query, + kv_cache=kv_cache, workspace_buffer=workspace_buffer, block_tables=block_tables, seq_lens=seq_lens, max_q_len=max_q_len, max_kv_len=max_seq_len, - bmm1_scale=k_scale * scale, - bmm2_scale=v_scale, - batch_size=num_seqs, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, + o_sf_scale=o_sf_scale, out=output_trtllm, ) + if o_quant_dtype == FP8_DTYPE: + output_trtllm = output_trtllm.to(dtype) * o_scale + elif o_quant_dtype == FP4_DTYPE: + output_trtllm.data = output_trtllm.data.reshape( + -1, query.shape[1] * query.shape[2] // 2) + output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data, + 
output_trtllm.scale, + o_sf_scale, dtype, + query.device) + output_trtllm = output_trtllm.reshape(-1, query.shape[1], + query.shape[2]) - torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ + if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: + rtol, atol = 4e-1, 1e0 + elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: + rtol, atol = 5e-2, 7e-2 + else: + rtol, atol = 1e-2, 1e-2 + + torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \ f"{torch.max(torch.abs(output - output_trtllm))}" diff --git a/tests/kernels/attention/test_flashmla.py b/tests/kernels/attention/test_flashmla.py index 81841be583528..abcfe828d5aca 100644 --- a/tests/kernels/attention/test_flashmla.py +++ b/tests/kernels/attention/test_flashmla.py @@ -13,11 +13,17 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, from vllm.triton_utils import triton -def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: +def cal_diff(x: torch.Tensor, + y: torch.Tensor, + name: str, + use_fp8: bool = False) -> None: x, y = x.double(), y.double() cos_diff = 1 - 2 * (x * y).sum().item() / max( (x * x + y * y).sum().item(), 1e-12) - assert cos_diff < 1e-5 + if (use_fp8): + assert cos_diff < 1e-4 + else: + assert cos_diff < 1e-5 FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ if not is_flashmla_supported()[0] else "FlashMLA is supported" @@ -27,7 +33,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ reason=FLASH_MLA_UNSUPPORTED_REASON) @pytest.mark.parametrize("b", [128]) @pytest.mark.parametrize("s_q", [1, 2]) -@pytest.mark.parametrize("mean_sk", [4096, 8192]) +@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384]) @pytest.mark.parametrize("h_q", [16, 32, 64, 128]) @pytest.mark.parametrize("h_kv", [1]) @pytest.mark.parametrize("d", [576]) @@ -35,20 +41,26 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ @pytest.mark.parametrize("block_size", [64]) @pytest.mark.parametrize("causal", 
[True]) @pytest.mark.parametrize("varlen", [False, True]) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("torch_dtype", + [torch.bfloat16, torch.float16, torch.float8_e4m3fn]) @torch.inference_mode() def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, - varlen, dtype): + varlen, torch_dtype): device = torch.device("cuda:0") - torch.set_default_dtype(dtype) + if torch_dtype == torch.float8_e4m3fn: + init_dtype = torch.bfloat16 + else: + init_dtype = torch_dtype + torch.set_default_dtype(init_dtype) torch.set_default_device(device) torch.cuda.set_device(device) torch.manual_seed(0) random.seed(0) print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " - f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}") + f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}") + use_fp8 = torch_dtype == torch.float8_e4m3fn cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) if varlen: for i in range(b): @@ -71,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, tile_scheduler_metadata, num_splits = get_mla_metadata( cache_seqlens, s_q * h_q // h_kv, h_kv) + init_dtype = q.dtype + if use_fp8: + fp8_dtype = torch.float8_e4m3fn + descale_q = torch.ones((1), dtype=torch.float32) + descale_k = torch.ones((1), dtype=torch.float32) + + q = q.to(fp8_dtype) + blocked_k = blocked_k.to(fp8_dtype) + blocked_v = blocked_v.to(fp8_dtype) + else: + descale_q = None + descale_k = None + def flash_mla(): return flash_mla_with_kvcache( q, @@ -81,6 +106,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, tile_scheduler_metadata, num_splits, causal=causal, + descale_q=descale_q, + descale_k=descale_k, ) def scaled_dot_product_attention(query, key, value, is_causal=False): @@ -104,29 +131,35 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, return attn_weight @ value, lse def ref_mla(): + q_ = (q.to(torch.float) * descale_q).to(init_dtype) if 
use_fp8 else q + blocked_k_ = (blocked_k.to(torch.float) * + descale_k).to(init_dtype) if use_fp8 else blocked_k + blocked_v_ = (blocked_v.to(torch.float) * + descale_k).to(init_dtype) if use_fp8 else blocked_v out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) lse = torch.empty(b, h_q, s_q, dtype=torch.float32) for i in range(b): begin = i * max_seqlen_pad end = begin + cache_seqlens[i] - ref_O, LSE = scaled_dot_product_attention( - q[i].transpose(0, 1), - blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), - blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + out_i, lse_i = scaled_dot_product_attention( + q_[i].transpose(0, 1), + blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1), is_causal=causal, ) - out[i] = ref_O.transpose(0, 1) - lse[i] = LSE + out[i] = out_i.transpose(0, 1) + lse[i] = lse_i return out, lse out_flash, lse_flash = flash_mla() out_torch, lse_torch = ref_mla() - cal_diff(out_flash, out_torch, "out") + cal_diff(out_flash, out_torch, "out", use_fp8) cal_diff(lse_flash, lse_torch, "lse") t = triton.testing.do_bench(flash_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + - b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) - print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} " - f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") + bytes = (total_seqlens * h_kv * d + + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + ( + b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) + print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,", + f"{bytes / 10 ** 6 / t:.0f} GB/s") diff --git a/tests/kernels/core/test_activation.py b/tests/kernels/core/test_activation.py index 29c5e70a8ba85..ec5c60fd7b0e2 100644 --- a/tests/kernels/core/test_activation.py +++ b/tests/kernels/core/test_activation.py @@ -11,7 +11,7 @@ from tests.kernels.utils import opcheck from vllm.model_executor.layers.activation import (FastGELU, 
FatreluAndMul, GeluAndMul, MulAndSilu, NewGELU, QuickGELU, - SiluAndMul) + SiluAndMul, SwigluOAIAndMul) from vllm.platforms import current_platform DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -25,7 +25,15 @@ CUDA_DEVICES = [ @pytest.mark.parametrize( "activation", - ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"]) + [ + "silu_and_mul", + "mul_and_silu", + "gelu", + "gelu_tanh", + "fatrelu", + "swigluoai_and_mul", + ], +) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @@ -59,18 +67,43 @@ def test_act_and_mul( threshold = random.uniform(0, 1) layer = FatreluAndMul(threshold) fn = torch.ops._C.fatrelu_and_mul + elif activation == "swigluoai_and_mul": + layer = SwigluOAIAndMul() + fn = torch.ops._C.swigluoai_and_mul out = layer(x) ref_out = layer.forward_native(x) - # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are - # equivalent to the native PyTorch implementations, so we can do exact - # comparison. - torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0) + if activation == "swigluoai_and_mul": + + rtol = { + #For fp16, change the relative tolerance from 1e-3 to 2e-3 + torch.float16: + 2e-3, + torch.bfloat16: + 2e-2, + torch.float: + 1.3e-6 + } + + def _get_rtol(output) -> float: + return rtol[output.dtype] + + torch.testing.assert_close(out, + ref_out, + atol=get_default_atol(out), + rtol=_get_rtol(out)) + else: + # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are + # equivalent to the native PyTorch implementations, so we can do exact + # comparison. 
+ torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0) d = x.shape[-1] // 2 output_shape = (x.shape[:-1] + (d, )) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) if activation == "fatrelu": opcheck(fn, (out, x, threshold)) + elif activation == "swigluoai_and_mul": + opcheck(fn, (out, x, layer.alpha, layer.limit)) else: opcheck(fn, (out, x)) diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index fd99e8dc5c987..a10666b6ec9a7 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -7,41 +7,22 @@ import torch import vllm._custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8 +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) from tests.kernels.utils import torch_experts from vllm.config import VllmConfig from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size -# Fused experts and PrepareFinalize imports -from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts) -from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts) +from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig) -from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 -from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts -from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts, NaiveBatchedExperts) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase, - 
TritonExperts) -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP) -from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts) from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx +from .mk_objects import (expert_info, make_fused_experts, + make_prepare_finalize, prepare_finalize_info) from .parallel_utils import ProcessGroupInfo -from .utils import (make_block_quant_fp8_weights, make_non_quant_weights, - make_quant_fp8_weights, per_token_cast_to_fp8) - -if has_pplx(): - from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize) -if has_deep_ep(): - from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 - DeepEPHTPrepareAndFinalize) - from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 - DeepEPLLPrepareAndFinalize) def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str: @@ -69,24 +50,31 @@ class Config: torch_trace_dir_path: Optional[str] = None + def __post_init__(self): + if self.quant_config is None: + self.quant_config = FusedMoEQuantConfig() + def describe(self) -> str: s = "" - s += "== Config: \n" - s += f" world_size={self.world_size} \n" - s += f" PF={self.prepare_finalize_type.__name__} \n" - s += f" FE={self.fused_experts_type.__name__} \n" - s += f" topk={self.topks} \n" - s += f" dtype={self.dtype} \n" - s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n" - s += " Quant: \n" - s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n " + s += "== Config:\n" + s += f" world_size={self.world_size}\n" + s += f" PF={self.prepare_finalize_type.__name__}\n" + s += f" FE={self.fused_experts_type.__name__}\n" + s += f" E={self.E}\n" + s += f" Ms={self.Ms}\n" + s += f" N={self.N}\n" + s += f" K={self.K}\n" + s += f" topk={self.topks}\n" + s += f" dtype={self.dtype}\n" + s += f" fused_moe_chunk_size={self.fused_moe_chunk_size}\n" + s += 
" Quant:\n" if self.quant_config is not None: - s += f" q_dtype={self.quant_dtype} \n" - s += f" q_block_shape={self.quant_block_shape} \n" - s += f" q_per_out_ch_quant={self.is_per_out_ch_quant} \n" - s += f" q_per_act_token={self.is_per_act_token_quant} \n" + s += f" q_dtype={self.quant_dtype}\n" + s += f" q_block_shape={self.quant_block_shape}\n" + s += f" q_per_out_ch_quant={self.is_per_out_ch_quant}\n" + s += f" q_per_act_token={self.is_per_act_token_quant}\n" else: - s += " quant=None \n" + s += " quant=None\n" return s @property @@ -95,34 +83,28 @@ class Config: return self.Ms @property - def quant_dtype(self) -> Optional[torch.dtype]: - if self.quant_config is None: - return None + def quant_dtype(self) -> Union[torch.dtype, str, None]: + assert self.quant_config is not None return self.quant_config.quant_dtype @property def is_per_act_token_quant(self) -> bool: - if self.quant_config is None: - return False + assert self.quant_config is not None return self.quant_config.per_act_token_quant @property def is_per_tensor_act_quant(self) -> bool: - if self.quant_config is None: - return False return (not self.is_per_act_token_quant and self.quant_block_shape is None) @property def is_per_out_ch_quant(self) -> bool: - if self.quant_config is None: - return False + assert self.quant_config is not None return self.quant_config.per_out_ch_quant @property def quant_block_shape(self) -> Optional[list[int]]: - if self.quant_config is None: - return None + assert self.quant_config is not None return self.quant_config.block_shape @property @@ -130,36 +112,30 @@ class Config: assert isinstance(self.topks, int) return self.topks - @property - def topk_ids_dtype(self) -> Optional[torch.dtype]: - topk_ids_dtype = None - if self.prepare_finalize_type == PplxPrepareAndFinalize: - topk_ids_dtype = torch.uint32 - elif self.prepare_finalize_type in [ - DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize - ]: - topk_ids_dtype = torch.int64 - return topk_ids_dtype - @property 
def num_local_experts(self) -> int: return self.E // self.world_size def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]: """ - make env data for vllm launch. + make env data for vllm launch. """ vllm_config = VllmConfig() vllm_config.parallel_config.data_parallel_size = self.world_size vllm_config.parallel_config.enable_expert_parallel = True env_dict = { - "VLLM_ALL2ALL_BACKEND": self.all2all_backend(), "VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())), } + + backend = self.all2all_backend() + if backend is not None: + env_dict.update({"VLLM_ALL2ALL_BACKEND": backend}) + if self.fused_moe_chunk_size is not None: env_dict.update( {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}) + return vllm_config, env_dict def is_fp8_block_quantized(self): @@ -167,85 +143,59 @@ class Config: and self.quant_block_shape is not None) def is_batched_prepare_finalize(self): - return self.prepare_finalize_type in [ - PplxPrepareAndFinalize, DeepEPLLPrepareAndFinalize - ] + info = prepare_finalize_info(self.prepare_finalize_type) + return (mk.FusedMoEActivationFormat.BatchedExperts == + info.activation_format) def is_batched_fused_experts(self): - return self.fused_experts_type in [ - CutlassExpertsFp8, BatchedDeepGemmExperts, BatchedTritonExperts, - NaiveBatchedExperts, BatchedTritonOrDeepGemmExperts - ] + info = expert_info(self.fused_experts_type) + return (mk.FusedMoEActivationFormat.BatchedExperts == + info.activation_format) def is_standard_fused_experts(self): - return self.fused_experts_type in [ - CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts, - TritonExperts - ] + info = expert_info(self.fused_experts_type) + return mk.FusedMoEActivationFormat.Standard == info.activation_format - def is_fe_16bit_supported(self): - return self.fused_experts_type in [ - BatchedTritonExperts, BatchedTritonOrDeepGemmExperts, - NaiveBatchedExperts, TritonExperts - ] + def fe_supported_types(self): + info = expert_info(self.fused_experts_type) + return 
info.supported_dtypes - def is_fe_fp8_supported(self): - return self.fused_experts_type in [ - BatchedDeepGemmExperts, - BatchedTritonExperts, - BatchedTritonOrDeepGemmExperts, - CutlassExpertsFp8, - DeepGemmExperts, - TritonExperts, - TritonOrDeepGemmExperts, - NaiveBatchedExperts, - ] + def pf_supported_types(self): + info = prepare_finalize_info(self.prepare_finalize_type) + return info.supported_dtypes - def is_fe_block_fp8_supported(self): - return self.fused_experts_type in [ - BatchedDeepGemmExperts, - BatchedTritonOrDeepGemmExperts, - DeepGemmExperts, - TritonExperts, - TritonOrDeepGemmExperts, - BatchedTritonExperts, - NaiveBatchedExperts, - ] + def is_block_quant_supported(self): + info = expert_info(self.fused_experts_type) + return info.blocked_quantization_support def is_fe_supports_chunking(self): - return self.fused_experts_type in [ - CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts, - TritonExperts - ] + info = expert_info(self.fused_experts_type) + return info.supports_chunking + + def supports_expert_map(self): + info = expert_info(self.fused_experts_type) + return info.supports_expert_map + + def supports_apply_weight_on_input(self): + info = prepare_finalize_info(self.prepare_finalize_type) + return info.supports_apply_weight_on_input def needs_deep_gemm(self): - return self.fused_experts_type in [ - BatchedDeepGemmExperts, - DeepGemmExperts, - ] + info = expert_info(self.fused_experts_type) + return info.needs_deep_gemm def needs_pplx(self): - return self.prepare_finalize_type in [PplxPrepareAndFinalize] + info = prepare_finalize_info(self.prepare_finalize_type) + return info.backend == "pplx" def needs_deep_ep(self): - return self.prepare_finalize_type in [ - DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize - ] + info = prepare_finalize_info(self.prepare_finalize_type) + return (info.backend == "deepep_high_throughput" + or info.backend == "deepep_low_latency") def all2all_backend(self): - if self.needs_pplx(): - return 
"pplx" - if self.prepare_finalize_type == DeepEPHTPrepareAndFinalize: - return "deepep_high_throughput" - if self.prepare_finalize_type == DeepEPLLPrepareAndFinalize: - return "deepep_low_latency" - return "naive" - - def needs_all2all(self): - return self.prepare_finalize_type in [ - PplxPrepareAndFinalize, DeepEPHTPrepareAndFinalize, - DeepEPLLPrepareAndFinalize - ] + info = prepare_finalize_info(self.prepare_finalize_type) + return info.backend def is_valid(self): # Check prepare-finalize and fused-experts compatibility @@ -267,28 +217,28 @@ class Config: # invalid quant config return False - # check bf16 / fp16 support - is_16bit = (self.dtype.itemsize == 2 and self.quant_dtype is None) - if is_16bit and not self.is_fe_16bit_supported(): - return False + # check type support + if self.quant_dtype is None: + if (self.dtype not in self.pf_supported_types() + or self.dtype not in self.fe_supported_types()): + return False + else: + if (self.quant_dtype not in self.pf_supported_types() + or self.quant_dtype not in self.fe_supported_types()): + return False - # Check fp8 support - is_fp8 = self.quant_dtype == torch.float8_e4m3fn - if is_fp8 and not self.is_fe_fp8_supported(): - return False - - # Check fp8 block quanization support + # Check block quanization support is_block_quatized = self.quant_block_shape is not None - if is_block_quatized and not is_fp8: + if is_block_quatized and self.quant_dtype is None: return False - if is_block_quatized and not self.is_fe_block_fp8_supported(): + if is_block_quatized and not self.is_block_quant_supported(): return False # deep_gemm only works with block-quantized if self.needs_deep_gemm() and not is_block_quatized: return False - # Check dependencies + # Check dependencies (turn into asserts?) 
if self.needs_deep_ep() and not has_deep_ep(): return False if self.needs_deep_gemm() and not has_deep_gemm(): @@ -305,6 +255,8 @@ class WeightTensors: w2: torch.Tensor w1_scale: Optional[torch.Tensor] w2_scale: Optional[torch.Tensor] + w1_gs: Optional[torch.Tensor] = None + w2_gs: Optional[torch.Tensor] = None def describe(self): s = "" @@ -313,13 +265,20 @@ class WeightTensors: s += f' - {_describe_tensor(self.w2, "w2")} \n' s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n' s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n' + s += f' - {_describe_tensor(self.w1_gs, "w1_gs")} \n' + s += f' - {_describe_tensor(self.w2_gs, "w2_gs")} \n' return s + def is_quantized(self) -> bool: + # or w1_scale is not None? + return (self.w1.dtype == torch.float8_e4m3fn + or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8) + def to_current_device(self): self.w1 = self.w1.to(device=torch.cuda.current_device()) self.w2 = self.w2.to(device=torch.cuda.current_device()) - is_quantized = self.w1.dtype == torch.float8_e4m3fn - if is_quantized: + + if self.is_quantized(): assert self.w1_scale is not None assert self.w2_scale is not None self.w1_scale = self.w1_scale.to( @@ -327,56 +286,51 @@ class WeightTensors: self.w2_scale = self.w2_scale.to( device=torch.cuda.current_device()) + if self.w1_gs is not None: + assert self.w2_gs is not None + self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device()) + self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device()) + def slice_weights(self, rank: int, num_local_experts: int) -> "WeightTensors": s = rank * num_local_experts e = s + num_local_experts w1 = self.w1[s:e, :, :] w2 = self.w2[s:e, :, :] - is_quantized = self.w1.dtype == torch.float8_e4m3fn + w1_scale, w2_scale = (None, None) - if is_quantized: + if self.is_quantized(): assert self.w1_scale is not None assert self.w2_scale is not None w1_scale = self.w1_scale[s:e, :, :] w2_scale = self.w2_scale[s:e, :, :] - return WeightTensors(w1, w2, w1_scale, 
w2_scale) + + w1_gs = self.w1_gs + w2_gs = self.w2_gs + if w1_gs is not None: + assert w2_gs is not None + w1_gs = w1_gs[s:e] + w2_gs = w2_gs[s:e] + + return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs) @staticmethod def make(config: Config) -> "WeightTensors": - - if config.quant_dtype is None: - # just make normal dtype weights - w1, w2 = make_non_quant_weights(e=config.E, - n=config.N, - k=config.K, - dtype=config.dtype) - return WeightTensors(w1=w1, w2=w2, w1_scale=None, w2_scale=None) - - assert config.quant_dtype == torch.float8_e4m3fn - if not config.is_fp8_block_quantized(): - w1, w2, w1_scale, w2_scale = make_quant_fp8_weights( - e=config.E, - n=config.N, - k=config.K, - per_out_channel_quant=config.is_per_out_ch_quant, - ) - return WeightTensors(w1=w1, - w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale) - - assert config.quant_block_shape is not None - w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights( + (_, w1, w1_scale, w1_gs), (_, w2, w2_scale, w2_gs) = make_test_weights( e=config.E, n=config.N, k=config.K, - block_size=config.quant_block_shape, + in_dtype=config.dtype, + quant_dtype=config.quant_dtype, + block_shape=config.quant_block_shape, + per_act_token_quant=config.is_per_out_ch_quant, ) return WeightTensors(w1=w1, w2=w2, w1_scale=w1_scale, - w2_scale=w2_scale) + w2_scale=w2_scale, + w1_gs=w1_gs, + w2_gs=w2_gs) @dataclass @@ -449,7 +403,6 @@ class RankTensors: dtype=dtype) topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False) - topk_ids = topk_ids.to(config.topk_ids_dtype) # distribute topk_ids evenly for mi in range(m): @@ -457,7 +410,7 @@ class RankTensors: topk_ids = topk_ids.to(device=torch.cuda.current_device()) expert_map = None - if config.world_size > 1: + if config.world_size > 1 and config.supports_expert_map(): expert_map = torch.full((global_num_experts, ), fill_value=-1, dtype=torch.int32) @@ -480,92 +433,100 @@ class RankTensors: def reference_moe_impl(config: Config, weights: WeightTensors, 
rank_tensors: RankTensors) -> torch.Tensor: - return torch_experts(a=rank_tensors.hidden_states, - w1=weights.w1, - w2=weights.w2, + if config.quant_dtype == "nvfp4": + quant_blocksize = 16 + dtype = config.dtype + + w1_q = weights.w1 + w1_blockscale = weights.w1_scale + w1_gs = weights.w1_gs + + w2_q = weights.w2 + w2_blockscale = weights.w2_scale + w2_gs = weights.w2_gs + + a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax( + rank_tensors.hidden_states.flatten(), dim=-1)).to(torch.float32) + + assert w1_gs is not None + assert w2_gs is not None + assert w1_blockscale is not None + assert w2_blockscale is not None + + assert w1_blockscale.shape[1] % 128 == 0 + assert w1_blockscale.shape[2] % 4 == 0 + assert w2_blockscale.shape[1] % 128 == 0 + assert w2_blockscale.shape[2] % 4 == 0 + + a_fp4, a_scale_interleaved = ops.scaled_fp4_quant( + rank_tensors.hidden_states, a_global_scale) + + a = dequantize_nvfp4_to_dtype(a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=dtype, + device=a_fp4.device, + block_size=quant_blocksize) + + e = w1_q.shape[0] + n = w1_q.shape[1] // 2 + k = w2_q.shape[1] + + w1 = torch.zeros((e, 2 * n, k), device="cuda", dtype=dtype) + w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype) + + for idx in range(0, e): + w1[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=dtype, + device=w1_q.device, + block_size=quant_blocksize) + w2[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=dtype, + device=w2_q.device, + block_size=quant_blocksize) + a_scale = None + w1_scale = None + w2_scale = None + quant_dtype = None + per_act_token_quant = False + block_shape = None + else: + a = rank_tensors.hidden_states + a_scale = rank_tensors.hidden_states_scale + w1 = weights.w1 + w1_scale = weights.w1_scale + w2 = weights.w2 + w2_scale = weights.w2_scale + quant_dtype = config.quant_dtype + per_act_token_quant = config.is_per_act_token_quant + block_shape = 
config.quant_block_shape + + return torch_experts(a=a, + w1=w1, + w2=w2, topk_weight=rank_tensors.topk_weights, topk_ids=rank_tensors.topk_ids, global_num_experts=config.E, expert_map=None, - w1_scale=weights.w1_scale, - w2_scale=weights.w2_scale, - a1_scale=rank_tensors.hidden_states_scale, - quant_dtype=config.quant_dtype, - per_act_token_quant=config.is_per_act_token_quant, - block_shape=config.quant_block_shape, - apply_router_weights_on_input=config.topk == 1) + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + apply_router_weights_on_input=config.topk == 1 + and config.supports_apply_weight_on_input()) -def make_fused_experts( - config: Config, moe: FusedMoEConfig, - num_dispatchers: int) -> mk.FusedMoEPermuteExpertsUnpermute: - - use_fp8 = config.quant_dtype == torch.float8_e4m3fn - batch_kwargs = { - "max_num_tokens": moe.max_num_tokens, - "num_dispatchers": num_dispatchers, - } - quant_kwargs = { - "use_fp8_w8a8": use_fp8, - "use_int8_w8a8": False, - "use_int8_w8a16": False, - "use_int4_w4a16": False, - "block_shape": config.quant_block_shape, - "per_act_token_quant": config.is_per_act_token_quant, - } - deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()} - - if config.fused_experts_type == BatchedDeepGemmExperts: - kwargs = batch_kwargs | { - "block_shape": config.quant_block_shape, - "per_act_token_quant": config.is_per_act_token_quant, - } - print(f"Making BatchedDeepGemmExperts {kwargs} ...") - experts = BatchedDeepGemmExperts(**kwargs) - elif config.fused_experts_type == BatchedTritonExperts: - kwargs = batch_kwargs | quant_kwargs - print(f"Making BatchedTritonExperts {kwargs} ...") - experts = BatchedTritonExperts(**kwargs) - elif config.fused_experts_type == BatchedTritonOrDeepGemmExperts: - kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs - print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...") - experts = 
BatchedTritonOrDeepGemmExperts(**kwargs) - elif config.fused_experts_type == DeepGemmExperts: - print("Making DeepGemmExperts () ...") - experts = DeepGemmExperts() - elif config.fused_experts_type == TritonExperts: - kwargs = quant_kwargs - print(f"Making TritonExperts {kwargs} ...") - experts = TritonExperts(**kwargs) - elif config.fused_experts_type == TritonOrDeepGemmExperts: - kwargs = quant_kwargs | deepgemm_kwargs - print(f"Making TritonOrDeepGemmExperts {kwargs} ...") - experts = TritonOrDeepGemmExperts(**kwargs) - elif config.fused_experts_type == NaiveBatchedExperts: - kwargs = batch_kwargs | quant_kwargs - print(f"Making NaiveBatchedExperts {kwargs} ...") - experts = NaiveBatchedExperts(**kwargs) - elif config.fused_experts_type == CutlassExpertsFp8: - use_batched_format = config.is_batched_prepare_finalize() - num_experts = (moe.num_local_experts - if use_batched_format else moe.num_experts) - kwargs = { - "max_experts_per_worker": num_experts, - "out_dtype": moe.in_dtype, - "per_act_token_quant": config.is_per_act_token_quant, - "per_out_ch_quant": config.is_per_out_ch_quant, - "block_shape": config.quant_block_shape, - "num_dispatchers": num_dispatchers, - "use_batched_format": use_batched_format - } - print(f"Making CutlassExpertsFp8 {kwargs} ...") - experts = CutlassExpertsFp8(**kwargs) - - return experts - - -def make_modular_kernel(config: Config, - vllm_config: VllmConfig) -> mk.FusedMoEModularKernel: +def make_modular_kernel( + config: Config, + vllm_config: VllmConfig, + weights: WeightTensors, +) -> mk.FusedMoEModularKernel: def next_power_of_2(x): import math @@ -579,6 +540,7 @@ def make_modular_kernel(config: Config, dp_size_=get_dp_group().world_size, vllm_parallel_config=vllm_config.parallel_config, ) + moe = FusedMoEConfig( num_experts=config.E, experts_per_token=config.topk, @@ -591,15 +553,16 @@ def make_modular_kernel(config: Config, ) # make modular kernel - prepare_finalize = None - if config.needs_all2all(): - prepare_finalize = 
FusedMoEMethodBase.maybe_make_prepare_finalize(moe) - assert prepare_finalize is not None - else: - prepare_finalize = MoEPrepareAndFinalizeNoEP() + prepare_finalize = make_prepare_finalize(config.prepare_finalize_type, + config.all2all_backend(), moe) - fused_experts = make_fused_experts(config, moe, - prepare_finalize.num_dispatchers()) + fused_experts = make_fused_experts( + config.fused_experts_type, + moe, + prepare_finalize.num_dispatchers(), + weights.w1_gs, + weights.w2_gs, + ) modular_kernel = mk.FusedMoEModularKernel( prepare_finalize=prepare_finalize, fused_experts=fused_experts) @@ -620,22 +583,45 @@ def run_modular_kernel( # weights for rank rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts) - mk = make_modular_kernel(config, vllm_config) + mk = make_modular_kernel(config, vllm_config, weights) mk_kwargs = { - "hidden_states": rank_tensors.hidden_states.clone( + "hidden_states": + rank_tensors.hidden_states.clone( ), # impls might update the tensor in place - "w1": rank_weights.w1, - "w2": rank_weights.w2, - "topk_weights": rank_tensors.topk_weights, - "topk_ids": rank_tensors.topk_ids, - "expert_map": rank_tensors.expert_map, - "w1_scale": rank_weights.w1_scale, - "w2_scale": rank_weights.w2_scale, - "a1_scale": rank_tensors.hidden_states_scale, - "global_num_experts": config.E, - "apply_router_weight_on_input": config.topk == 1, + "w1": + rank_weights.w1, + "w2": + rank_weights.w2, + "topk_weights": + rank_tensors.topk_weights, + "topk_ids": + rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()), + "expert_map": + rank_tensors.expert_map, + "w1_scale": + rank_weights.w1_scale, + "w2_scale": + rank_weights.w2_scale, + "a1_scale": + rank_tensors.hidden_states_scale, + "global_num_experts": + config.E, + "apply_router_weight_on_input": + config.topk == 1 and config.supports_apply_weight_on_input(), } - out = mk.forward(**mk_kwargs) + + num_tokens = rank_tensors.hidden_states.shape[0] + num_tokens_across_dp = 
torch.tensor([num_tokens] * config.world_size, + device="cuda", + dtype=torch.int) + + with set_forward_context( + None, + vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + ): + out = mk.forward(**mk_kwargs) return out diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 73214066f7ea6..aecffae36ae5e 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -1,58 +1,316 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Optional, Union import torch # Fused experts and PrepareFinalize imports +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts) from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 BatchedTritonOrDeepGemmExperts) -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts, NaiveBatchedExperts) -from vllm.model_executor.layers.fused_moe.layer import TritonExperts +from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase, + TritonExperts) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts) -from vllm.utils import has_deep_ep, has_pplx +from vllm.model_executor.layers.quantization.utils.quant_utils import ( 
+ cutlass_fp4_supported) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + cutlass_fp8_supported) +from vllm.platforms import current_platform +from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx +from vllm.utils.deep_gemm import is_deep_gemm_supported +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe -if has_deep_ep(): + +@dataclass +class PrepareFinalizeInfo: + activation_format: mk.FusedMoEActivationFormat + supported_dtypes: list[Union[torch.dtype, str]] + blocked_quantization_support: bool + backend: Optional[str] + supports_apply_weight_on_input: bool = True + + +@dataclass +class ExpertInfo: + activation_format: mk.FusedMoEActivationFormat + supported_dtypes: list[Union[torch.dtype, str]] + blocked_quantization_support: bool + supports_chunking: bool + supports_expert_map: bool + needs_matching_quant: bool = False + needs_deep_gemm: bool = False + + +PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, + PrepareFinalizeInfo] = {} +EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {} +MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] +MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] +MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] +MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = [] + +standard_format = mk.FusedMoEActivationFormat.Standard +batched_format = mk.FusedMoEActivationFormat.BatchedExperts +common_float_types: list[Union[torch.dtype, str]] = [ + torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32 +] +common_float_and_int_types = common_float_types + [torch.int8] +nv_fp4_types = ["nvfp4"] +fp8_types = [torch.float8_e4m3fn] + + +def register_prepare_and_finalize( + kind, + activation_format: mk.FusedMoEActivationFormat, + supported_dtypes: list[Union[torch.dtype, str]], + blocked_quantization_support: bool, + backend: Optional[str], + force_multigpu: bool = False, + 
supports_apply_weight_on_input: bool = True, +): + global PREPARE_FINALIZE_INFO + global MK_ALL_PREPARE_FINALIZE_TYPES + global MK_MULTI_GPU_PREPARE_FINALIZE_TYPES + global MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES + assert kind not in PREPARE_FINALIZE_INFO + + PREPARE_FINALIZE_INFO[kind] = PrepareFinalizeInfo( + activation_format, + supported_dtypes, + blocked_quantization_support, + backend, + supports_apply_weight_on_input, + ) + MK_ALL_PREPARE_FINALIZE_TYPES.append(kind) + if backend is not None or force_multigpu: + MK_MULTI_GPU_PREPARE_FINALIZE_TYPES.append(kind) + else: + MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES.append(kind) + + +def register_experts( + kind, + activation_format: mk.FusedMoEActivationFormat, + supported_dtypes: list[Union[torch.dtype, str]], + blocked_quantization_support: bool, + supports_chunking: bool, + supports_expert_map: bool, + needs_matching_quant: bool = False, + needs_deep_gemm: bool = False, +): + global EXPERT_INFO + global MK_FUSED_EXPERT_TYPES + assert kind not in EXPERT_INFO + + EXPERT_INFO[kind] = ExpertInfo( + activation_format, + supported_dtypes, + blocked_quantization_support, + supports_chunking, + supports_expert_map, + needs_matching_quant, + needs_deep_gemm, + ) + + MK_FUSED_EXPERT_TYPES.append(kind) + + +def prepare_finalize_info(kind) -> PrepareFinalizeInfo: + info = PREPARE_FINALIZE_INFO.get(kind) + assert info is not None + return info + + +def expert_info(kind) -> ExpertInfo: + info = EXPERT_INFO.get(kind) + assert info is not None + return info + + +register_prepare_and_finalize( + MoEPrepareAndFinalizeNoEP, + standard_format, + common_float_types, + blocked_quantization_support=True, + backend=None, +) + +register_experts( + BatchedTritonExperts, + batched_format, + common_float_types, + blocked_quantization_support=True, + supports_chunking=False, + supports_expert_map=False, + needs_matching_quant=True, +) + +register_experts( + TritonExperts, + standard_format, + common_float_and_int_types, + 
blocked_quantization_support=True, + supports_chunking=True, + supports_expert_map=True, + needs_matching_quant=True, +) + +register_experts( + NaiveBatchedExperts, + batched_format, + common_float_and_int_types, + blocked_quantization_support=True, + supports_chunking=False, + supports_expert_map=True, +) + +# Disable on blackwell for now +if has_deep_ep() and not current_platform.has_device_capability(100): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 DeepEPHTPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 DeepEPLLPrepareAndFinalize) + register_prepare_and_finalize( + DeepEPHTPrepareAndFinalize, + standard_format, + common_float_types, + blocked_quantization_support=True, + backend="deepep_high_throughput", + ) + + register_prepare_and_finalize( + DeepEPLLPrepareAndFinalize, + batched_format, + common_float_types, + blocked_quantization_support=True, + backend="deepep_low_latency", + ) + if has_pplx(): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( PplxPrepareAndFinalize) + register_prepare_and_finalize( + PplxPrepareAndFinalize, + batched_format, + common_float_and_int_types, + blocked_quantization_support=True, + backend="pplx", + ) -MK_MULTI_GPU_PREPARE_FINALIZE_TYPES = [] -if has_pplx(): - MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [PplxPrepareAndFinalize] -if has_deep_ep(): - MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [ - DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize - ] +if (has_flashinfer_cutlass_fused_moe() + and current_platform.has_device_capability(100)): + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + FlashInferExperts) + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 + FlashInferCutlassMoEPrepareAndFinalize) -MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES = [MoEPrepareAndFinalizeNoEP] + register_prepare_and_finalize( + 
FlashInferCutlassMoEPrepareAndFinalize, + standard_format, + nv_fp4_types, + blocked_quantization_support=True, + backend=None, + force_multigpu=True, + supports_apply_weight_on_input=False, + ) -MK_ALL_PREPARE_FINALIZE_TYPES = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES + - MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) + register_experts( + FlashInferExperts, + standard_format, + nv_fp4_types, + blocked_quantization_support=True, + supports_chunking=True, + # Note: this is a hack to get it to run for now + supports_expert_map=True, + ) +else: + FlashInferCutlassMoEPrepareAndFinalize = None -MK_FUSED_EXPERT_TYPES = [ - BatchedDeepGemmExperts, - BatchedTritonExperts, - NaiveBatchedExperts, - BatchedTritonOrDeepGemmExperts, - CutlassExpertsFp8, - DeepGemmExperts, - TritonOrDeepGemmExperts, - TritonExperts, -] +if has_deep_gemm() and is_deep_gemm_supported(): + register_experts( + BatchedDeepGemmExperts, + batched_format, + fp8_types, + blocked_quantization_support=True, + supports_chunking=False, + supports_expert_map=False, + needs_matching_quant=False, + needs_deep_gemm=True, + ) + register_experts( + DeepGemmExperts, + standard_format, + fp8_types, + blocked_quantization_support=True, + supports_chunking=True, + supports_expert_map=True, + needs_matching_quant=False, + needs_deep_gemm=True, + ), + register_experts( + BatchedTritonOrDeepGemmExperts, + batched_format, + common_float_and_int_types, + blocked_quantization_support=True, + supports_chunking=False, + supports_expert_map=False, + needs_matching_quant=True, + needs_deep_gemm=True, + ) + register_experts( + TritonOrDeepGemmExperts, + standard_format, + common_float_and_int_types, + blocked_quantization_support=True, + supports_chunking=True, + supports_expert_map=True, + needs_matching_quant=True, + needs_deep_gemm=True, + ) + +if cutlass_fp8_supported(): + from vllm.model_executor.layers.fused_moe import (CutlassBatchedExpertsFp8, + CutlassExpertsFp8) + register_experts( + CutlassExpertsFp8, + standard_format, + 
fp8_types, + blocked_quantization_support=False, + supports_chunking=True, + supports_expert_map=False, + ) + register_experts( + CutlassBatchedExpertsFp8, + batched_format, + fp8_types, + blocked_quantization_support=False, + supports_chunking=False, + supports_expert_map=False, + ) + +if cutlass_fp4_supported(): + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + CutlassExpertsFp4) + register_experts( + CutlassExpertsFp4, + standard_format, + nv_fp4_types, + blocked_quantization_support=True, + supports_chunking=True, + supports_expert_map=False, + ) MK_QUANT_CONFIGS = [ None, @@ -85,3 +343,156 @@ MK_QUANT_CONFIGS = [ # block-quantized weights and per-token activations # block-quantized weights and per-tensor activations ] + +if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe(): + MK_QUANT_CONFIGS += [ + FusedMoEQuantConfig(quant_dtype="nvfp4", + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=None), + ] + + +def _make_gscale(num_experts: int) -> torch.Tensor: + return torch.ones((num_experts, ), + device=torch.cuda.current_device(), + dtype=torch.float32) + + +def make_prepare_finalize( + prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, + backend: Optional[str], + moe: FusedMoEConfig, +) -> mk.FusedMoEPrepareAndFinalize: + if backend != "naive" and backend is not None: + prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe) + assert prepare_finalize is not None + return prepare_finalize + elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize: + return FlashInferCutlassMoEPrepareAndFinalize( + use_dp=moe.moe_parallel_config.dp_size > 1, + a1_gscale=_make_gscale(moe.num_local_experts), + ) + else: + return MoEPrepareAndFinalizeNoEP() + + +def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor: + s = rank * num_local_experts + e = s + num_local_experts + return t[s:e] + + +def make_fused_experts( + fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, + 
moe: FusedMoEConfig, + num_dispatchers: int, + w1_gs: Optional[torch.Tensor], + w2_gs: Optional[torch.Tensor], +) -> mk.FusedMoEPermuteExpertsUnpermute: + + use_fp8 = moe.quant_dtype == torch.float8_e4m3fn + batch_kwargs = { + "max_num_tokens": moe.max_num_tokens, + "num_dispatchers": num_dispatchers, + } + quant_kwargs = { + "use_fp8_w8a8": use_fp8, + "use_int8_w8a8": False, + "use_int8_w8a16": False, + "use_int4_w4a16": False, + "block_shape": moe.block_shape, + "per_act_token_quant": moe.per_act_token_quant, + } + deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()} + + if fused_experts_type == BatchedDeepGemmExperts: + kwargs = batch_kwargs | { + "block_shape": moe.block_shape, + "per_act_token_quant": moe.per_act_token_quant, + } + print(f"Making BatchedDeepGemmExperts {kwargs} ...") + experts = BatchedDeepGemmExperts(**kwargs) + elif fused_experts_type == BatchedTritonExperts: + kwargs = batch_kwargs | quant_kwargs + print(f"Making BatchedTritonExperts {kwargs} ...") + experts = BatchedTritonExperts(**kwargs) + elif fused_experts_type == BatchedTritonOrDeepGemmExperts: + kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs + print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...") + experts = BatchedTritonOrDeepGemmExperts(**kwargs) + elif fused_experts_type == DeepGemmExperts: + print("Making DeepGemmExperts () ...") + experts = DeepGemmExperts() + elif fused_experts_type == TritonExperts: + kwargs = quant_kwargs + print(f"Making TritonExperts {kwargs} ...") + experts = TritonExperts(**kwargs) + elif fused_experts_type == TritonOrDeepGemmExperts: + kwargs = quant_kwargs | deepgemm_kwargs + print(f"Making TritonOrDeepGemmExperts {kwargs} ...") + experts = TritonOrDeepGemmExperts(**kwargs) + elif fused_experts_type == NaiveBatchedExperts: + kwargs = batch_kwargs | quant_kwargs + print(f"Making NaiveBatchedExperts {kwargs} ...") + experts = NaiveBatchedExperts(**kwargs) + elif fused_experts_type == CutlassExpertsFp8: + kwargs = { + "out_dtype": 
moe.in_dtype, + "per_act_token_quant": moe.per_act_token_quant, + "per_out_ch_quant": moe.per_out_ch_quant, + "block_shape": moe.block_shape, + } + print(f"Making CutlassExpertsFp8 {kwargs} ...") + experts = CutlassExpertsFp8(**kwargs) + elif fused_experts_type == CutlassBatchedExpertsFp8: + kwargs = { + "max_experts_per_worker": moe.num_local_experts, + "num_dispatchers": num_dispatchers, + "out_dtype": moe.in_dtype, + "per_act_token_quant": moe.per_act_token_quant, + "per_out_ch_quant": moe.per_out_ch_quant, + "block_shape": moe.block_shape, + } + print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...") + experts = CutlassBatchedExpertsFp8(**kwargs) + elif fused_experts_type == CutlassExpertsFp4: + assert w1_gs is not None and w2_gs is not None + num_experts = moe.num_local_experts + rank = moe.moe_parallel_config.dp_rank + kwargs = { + "g1_alphas": _slice(rank, num_experts, (1 / w1_gs)), + "g2_alphas": _slice(rank, num_experts, (1 / w2_gs)), + "a1_gscale": _make_gscale(num_experts), + "a2_gscale": _make_gscale(num_experts), + "max_experts_per_worker": num_experts, + "out_dtype": moe.in_dtype, + "per_act_token_quant": moe.per_act_token_quant, + "per_out_ch_quant": moe.per_out_ch_quant, + "block_shape": moe.block_shape, + "num_dispatchers": num_dispatchers, + } + print(f"Making CutlassExpertsFp4 {kwargs} ...") + experts = CutlassExpertsFp4(**kwargs) + elif fused_experts_type == FlashInferExperts: + assert w1_gs is not None and w2_gs is not None + num_experts = moe.num_local_experts + rank = moe.moe_parallel_config.dp_rank + kwargs = { + "g1_alphas": _slice(rank, num_experts, (1 / w1_gs)), + "g2_alphas": _slice(rank, num_experts, (1 / w2_gs)), + "a1_gscale": _make_gscale(num_experts), + "a2_gscale": _make_gscale(num_experts), + "out_dtype": moe.in_dtype, + "quant_dtype": "nvfp4", + "ep_rank": moe.ep_rank, + "ep_size": moe.ep_size, + "tp_rank": moe.tp_rank, + "tp_size": moe.tp_size, + } + print(f"Making FlashInferExperts {kwargs} ...") + experts = 
FlashInferExperts(**kwargs) + else: + raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}") + + return experts diff --git a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py index dd16ffb2eabec..0da6ee3543521 100644 --- a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py +++ b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py @@ -52,7 +52,7 @@ def profile_modular_kernel( rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts) # make modular kernel - mk = make_modular_kernel(config, vllm_config) + mk = make_modular_kernel(config, vllm_config, weights) mk_kwargs = { "hidden_states": rank_tensors.hidden_states, @@ -83,7 +83,7 @@ def rank_worker( # sanity check from vllm import envs if config.fused_moe_chunk_size is not None: - assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE # get weights to this device weights.to_current_device() diff --git a/tests/kernels/moe/modular_kernel_tools/utils.py b/tests/kernels/moe/modular_kernel_tools/utils.py deleted file mode 100644 index 866f52882beee..0000000000000 --- a/tests/kernels/moe/modular_kernel_tools/utils.py +++ /dev/null @@ -1,117 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -import vllm._custom_ops as ops -from vllm.utils.deep_gemm import per_block_cast_to_fp8 - - -def per_token_cast_to_fp8( - x: torch.Tensor, block_size: int) -> tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - pad_size = (block_size - (n % block_size)) % block_size - x = torch.nn.functional.pad(x, - (0, pad_size), value=0) if pad_size > 0 else x - x_view = x.view(m, -1, block_size) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - fp8_data = (x_view * (448.0 / 
x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) - return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) - - -def make_non_quant_weights( - e: int, - n: int, - k: int, - dtype: torch.dtype, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Return weights w1, w2 - """ - device = torch.cuda.current_device() - w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 15 - w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 15 - return w1, w2 - - -def make_block_quant_fp8_weights( - e: int, - n: int, - k: int, - block_size: list[int], -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Return weights w1, w2, w1_scale, w2_scale - """ - dtype = torch.bfloat16 - device = torch.cuda.current_device() - - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - w1_bf16, w2_bf16 = make_non_quant_weights(e, n, k, dtype) - w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype) - w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype) - - block_n, block_k = block_size[0], block_size[1] - n_tiles_w1 = ((2 * n) + block_n - 1) // block_n - k_tiles_w1 = (k + block_k - 1) // block_k - n_tiles_w2 = (k + block_n - 1) // block_n - k_tiles_w2 = (n + block_k - 1) // block_k - - w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn, device=device) - w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn, device=device) - - w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1), - device=device, - dtype=torch.float32) - w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2), - device=device, - dtype=torch.float32) - - assert w1_s.shape == (e, (2 * n + (block_n - 1)) // block_n, - (k + (block_k - 1)) // block_k) - assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] - - for i in range(e): - w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i], - block_size=[block_k, block_n]) - w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i], - block_size=[block_k, block_n]) - - return w1, w2, 
w1_s, w2_s - - -def make_quant_fp8_weights( - e: int, - n: int, - k: int, - per_out_channel_quant: bool, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Return w1, w2, w1_scale, w2_scale - """ - q_dtype = torch.float8_e4m3fn - - w1, w2 = make_non_quant_weights(e, n, k, dtype=torch.bfloat16) - - # w1 -> w1_q, w2 -> w2_q - w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype) - w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype) - - n_b_scales = 2 * n if per_out_channel_quant else 1 - k_b_scales = k if per_out_channel_quant else 1 - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) - - for expert in range(e): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_channel_quant) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_channel_quant) - return w1_q, w2_q, w1_scale, w2_scale diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index edf3e61892430..00b2d780e66f5 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -133,7 +133,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, per_act_token_quant=per_act_token_quant, ) - B, B_q, B_scale, _, _, _ = make_test_weights( + (B, B_q, B_scale, _), _ = make_test_weights( num_experts, N // 2, K, @@ -243,7 +243,7 @@ def test_fused_moe_batched_experts( act_dtype = dtype quant_dtype = None - w1_16, w1, w1_s, w2_16, w2, w2_s = make_test_weights( + (w1_16, w1, w1_s, _), (w2_16, w2, w2_s, _) = make_test_weights( e, n, k, diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index 75b2e9f791789..ecc57acc67963 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -16,7 +16,7 @@ from 
vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used dg_available = has_deep_gemm() @@ -161,18 +161,20 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - _, w1, w1_s, _, w2, w2_s = make_test_weights(E, - N, - K, - dtype, - torch.float8_e4m3fn, - per_act_token_quant=False, - block_shape=block_size) + (_, w1, w1_s, _), (_, w2, w2_s, + _) = make_test_weights(E, + N, + K, + dtype, + torch.float8_e4m3fn, + per_act_token_quant=False, + block_shape=block_size) m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True, use_int8_w8a8=False, use_int8_w8a16=False, use_int4_w4a16=False, + use_mxfp4_w4a4=False, per_act_token_quant=False, block_shape=block_size) @@ -224,8 +226,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") -@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), - reason="Not E8M0 scale MOE") +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch): @@ -247,13 +248,14 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - _, w1, w1_s, _, w2, w2_s = make_test_weights(E, - N, - K, - dtype, - torch.float8_e4m3fn, - per_act_token_quant=False, - block_shape=block_size) + (_, w1, w1_s, _), (_, w2, w2_s, + _) = make_test_weights(E, + N, + K, + dtype, + torch.float8_e4m3fn, + per_act_token_quant=False, + 
block_shape=block_size) # Note: for now use_compile will error out if the problem size is # large enough to trigger chunking. I'm leaving the flag and diff --git a/tests/kernels/moe/test_block_int8.py b/tests/kernels/moe/test_block_int8.py index 8e680c722935b..5e4a93963f8e8 100644 --- a/tests/kernels/moe/test_block_int8.py +++ b/tests/kernels/moe/test_block_int8.py @@ -118,13 +118,14 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - _, w1, w1_s, _, w2, w2_s = make_test_weights(E, - N, - K, - dtype, - torch.int8, - per_act_token_quant=False, - block_shape=block_size) + (_, w1, w1_s, _), (_, w2, w2_s, + _) = make_test_weights(E, + N, + K, + dtype, + torch.int8, + per_act_token_quant=False, + block_shape=block_size) # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): diff --git a/tests/kernels/moe/test_cutlass_grouped_gemm.py b/tests/kernels/moe/test_cutlass_grouped_gemm.py index 1aee1ed8c3762..3b1618dacac7b 100644 --- a/tests/kernels/moe/test_cutlass_grouped_gemm.py +++ b/tests/kernels/moe/test_cutlass_grouped_gemm.py @@ -9,6 +9,7 @@ import random import pytest import torch +from tests.kernels.moe.utils import per_token_cast_to_fp8 from tests.kernels.utils import baseline_scaled_mm from vllm import _custom_ops as ops from vllm.platforms import current_platform @@ -16,20 +17,6 @@ from vllm.utils import cdiv from vllm.utils.deep_gemm import per_block_cast_to_fp8 -def per_token_cast_to_fp8( - x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - pad_size = (128 - (n % 128)) % 128 - x = torch.nn.functional.pad(x, - (0, pad_size), value=0) if pad_size > 0 else x - x_view = x.view(m, -1, 128) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - fp8_data = (x_view * - (448.0 / x_amax.unsqueeze(2))).to(dtype=torch.float8_e4m3fn) - return fp8_data.view(m, n + 
pad_size)[:, :n], (x_amax / 448.0).view(m, -1) - - @pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [ (4, 8192, 7168, 4096), (4, 8192, 2048, 7168), @@ -76,7 +63,7 @@ def test_cutlass_grouped_gemm( device=device, dtype=torch.float)) for i in range(num_groups): - y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128]) for i in range(num_groups): a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]] diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 81fb3ec1de188..c84f66383b902 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit, 'topk_ids': topk_ids, 'w1_scale': moe_tensors.w1_scale, 'w2_scale': moe_tensors.w2_scale, + 'ab_strides1': moe_tensors.ab_strides1, + 'ab_strides2': moe_tensors.ab_strides2, + 'c_strides1': moe_tensors.c_strides1, + 'c_strides2': moe_tensors.c_strides2, 'per_act_token': per_act_token, 'a1_scale': None #moe_tensors.a_scale } @@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8( topk_ids[0][1] = 1 workspace13_shape = (m * topk, max(2 * n, k)) - workspace2_shape = (m * topk, n) - output_shape = (m * topk, k) + workspace2_shape = (m * topk, max(n, k)) + output_shape = (m, k) workspace13 = torch.empty(prod(workspace13_shape), device="cuda", @@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8( expert_map[start:end] = list(range(num_local_experts)) expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda") + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + activation = lambda o, i: torch.ops._C.silu_and_mul(o, i) a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale, 
torch.float8_e4m3fn, @@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8( func = lambda output: run_cutlass_moe_fp8( output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation, global_num_experts, expert_map, mt.w1_scale, mt.w2_scale, - a1q_scale, None, workspace13, workspace2, None, mt.a.dtype, - per_act_token, per_out_channel, False) + a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2, + workspace13, workspace2, None, mt.a.dtype, per_act_token, + per_out_channel, False, topk_weights) workspace13.random_() output_random_workspace = torch.empty(output_shape, diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 9b064db973ddf..36a98522a6588 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -20,9 +20,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) from vllm.platforms import current_platform from vllm.utils import has_deep_ep, has_deep_gemm -from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used, - is_deep_gemm_supported) +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported +from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch from .utils import make_test_weights @@ -70,8 +70,10 @@ def make_block_quant_fp8_weights( """ Return weights w1q, w2q, w1_scale, w2_scale """ - w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights( - e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size) + (_, w1q, w1_scale, _), (_, w2q, w2_scale, + _) = make_test_weights(e, n, k, torch.bfloat16, + torch.float8_e4m3fn, + block_size) return w1q, w2q, w1_scale, w2_scale @@ -368,9 +370,10 @@ NUM_EXPERTS = [32] @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOPKS) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) +@multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm 
-@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM") def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, topk: int, world_dp_size: tuple[int, int]): @@ -425,9 +428,10 @@ USE_FP8_DISPATCH = [False] @pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH) @pytest.mark.parametrize("block_size", [[128, 128]]) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) +@multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM") def test_ll_deepep_deepgemm_moe( mnk: tuple[int, int, int], diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index 43804c410b6c2..6a53af68cd53a 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.platforms import current_platform from vllm.utils import has_deep_ep +from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch if has_deep_ep(): @@ -411,6 +412,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn] @pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("per_act_token_quant", [False, True]) +@multi_gpu_test(num_gpus=2) @requires_deep_ep def test_deep_ep_moe( dtype: torch.dtype, @@ -459,6 +461,7 @@ USE_FP8_DISPATCH = [True, False] @pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH) +@multi_gpu_test(num_gpus=2) @requires_deep_ep def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], num_experts: int, topk: int, diff --git a/tests/kernels/moe/test_deepgemm.py 
b/tests/kernels/moe/test_deepgemm.py index b2b78662c9ded..4472f34a6291a 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -132,9 +132,9 @@ def run_single_case(m, n, k, topk, num_experts, block_size): # Note: W1 has shape (E, 2N, K), so N = 512 # can trigger the deepgemm path. MNKs = [ - (1024, 512, 128), - (1024, 512, 512), - (2048, 512, 512), + (1024, 768, 128), + (1024, 768, 512), + (2048, 768, 512), (512, 1024, 1024), (512, 2048, 2048), (4096, 4096, 1024), diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py new file mode 100644 index 0000000000000..52a3d2ca3b422 --- /dev/null +++ b/tests/kernels/moe/test_flashinfer.py @@ -0,0 +1,248 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass + +import pytest +import torch + +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8, + register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, + swap_w13_to_w31) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + input_to_float8) +from vllm.model_executor.models.llama4 import Llama4MoE +from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe + +if not has_flashinfer_cutlass_fused_moe( +) or not current_platform.has_device_capability(100): + pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support", + allow_module_level=True) + +NUM_EXPERTS = [16] +TOP_KS = [1] + +MNK_FACTORS = [ + (256, 8192, 5120), + (256, 4096, 5120), + (127, 8192, 5120), + (127, 4096, 5120), + (10, 8192, 5120), + (10, 4096, 5120), + (1, 
8192, 5120), + (1, 4096, 5120), +] + +vllm_config = VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1)) +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + + +def quant_fp8_per_tensor_batches(a): + num_batches = a.size(0) + a_quant = [] + a_scales = [] + + for i in range(num_batches): + a_fp8, a_global_sf = input_to_float8(a[i]) + a_global_sf = 1.0 / a_global_sf + a_quant.append(a_fp8) + a_scales.append(a_global_sf) + + result_a_quant = torch.stack(a_quant) + result_a_scales = torch.stack(a_scales) + + return result_a_quant, result_a_scales + + +@dataclass +class TestData: + hidden_states: torch.Tensor + w13_quantized: torch.Tensor + w2_quantized: torch.Tensor + a1_scale: torch.Tensor + a2_scale: torch.Tensor + w13_weight_scale: torch.Tensor + w2_weight_scale: torch.Tensor + layer: torch.nn.Module + + @staticmethod + def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, + reorder: bool) -> "TestData": + hidden_states = torch.randn( + (m, k), device="cuda", dtype=torch.bfloat16) / 10 + w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) + w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) + + # Scale to fp8 + _, a1_scale = input_to_float8(hidden_states) + a1_scale = 1.0 / a1_scale + a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to( + dtype=torch.float32) + w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13) + w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2) + + layer = torch.nn.Module() + layer.w13_weight = w13_quantized.clone() + layer.w2_weight = w2_quantized.clone() + layer.w13_input_scale = a1_scale + layer.w2_input_scale = a2_scale + layer.w13_weight_scale = w13_weight_scale + layer.w2_weight_scale = w2_weight_scale + + register_moe_scaling_factors(layer) + + # flashinfer expects swapped rows for w13 + layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) + if reorder: + 
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, + layer.w2_weight) + layer.custom_routing_function = Llama4MoE.custom_routing_function + layer.intermediate_size_per_partition = n + layer.ep_rank = 0 + layer.local_num_experts = e + + return TestData( + hidden_states=hidden_states, + w13_quantized=w13_quantized, + w2_quantized=w2_quantized, + a1_scale=a1_scale, + a2_scale=a2_scale, + w13_weight_scale=w13_weight_scale, + w2_weight_scale=w2_weight_scale, + layer=layer, + ) + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +def test_flashinfer_per_tensor_moe_fp8_no_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + monkeypatch, +): + current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") + with set_current_vllm_config(vllm_config): + td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True) + + score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=td.hidden_states, + router_logits=score, + use_grouped_topk=False, + top_k=topk, + renormalize=False, + custom_routing_function=Llama4MoE.custom_routing_function, + scoring_func="softmax") + + output = fused_experts( + td.hidden_states, + td.w13_quantized, + td.w2_quantized, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + activation="silu", + use_fp8_w8a8=True, + per_channel_quant=False, + global_num_experts=e, + expert_map=None, + w1_scale=td.w13_weight_scale, + w2_scale=td.w2_weight_scale, + a1_scale=td.a1_scale, + a2_scale=td.a2_scale, + apply_router_weight_on_input=True, + ) + + flashinfer_output = apply_flashinfer_per_tensor_scale_fp8( + layer=td.layer, + hidden_states=td.hidden_states, + router_logits=score, + routing_bias=None, + global_num_experts=e, + top_k=topk, + num_expert_group=None, + topk_group=None, + apply_router_weight_on_input=True) + + 
torch.testing.assert_close(output, + flashinfer_output, + atol=5.5e-2, + rtol=1e-2) + + +@pytest.mark.skip( + "Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472" +) +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +def test_flashinfer_cutlass_moe_fp8_no_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + monkeypatch, +): + current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") + with set_current_vllm_config(vllm_config): + td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False) + + score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=td.hidden_states, + router_logits=score, + use_grouped_topk=False, + top_k=topk, + renormalize=False, + custom_routing_function=Llama4MoE.custom_routing_function, + scoring_func="softmax") + + output = fused_experts( + td.hidden_states, + td.w13_quantized, + td.w2_quantized, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + activation="silu", + use_fp8_w8a8=True, + per_channel_quant=False, + global_num_experts=e, + expert_map=None, + w1_scale=td.w13_weight_scale, + w2_scale=td.w2_weight_scale, + a1_scale=td.a1_scale, + a2_scale=td.a2_scale, + apply_router_weight_on_input=True, + ) + + td.layer.dp_size = 1 + + flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8( + td.hidden_states, + td.layer, + topk_weights, + topk_ids, + activation="silu", + global_num_experts=e, + expert_map=None, + apply_router_weight_on_input=True, + ) + + torch.testing.assert_close(output, + flashinfer_cutlass_output, + atol=5.5e-2, + rtol=1e-2) diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py new file mode 100644 index 0000000000000..1c14df2b914aa --- /dev/null +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -0,0 +1,147 @@ +# 
SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from tests.kernels.moe.utils import make_test_weights +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) +from tests.kernels.utils import torch_moe +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP) +from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe + +if not has_flashinfer_cutlass_fused_moe( +) or not current_platform.has_device_capability(100): + pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support", + allow_module_level=True) + +MNK_FACTORS = [ + (2, 1024, 1024), + (2, 1024, 1536), + (2, 3072, 1024), + (2, 3072, 1536), + (64, 1024, 1024), + (64, 1024, 1536), + (64, 3072, 1024), + (64, 2048, 1536), + (224, 1024, 1024), + (224, 1024, 1536), +] + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", [40, 64, 256]) +#@pytest.mark.parametrize("e", [128, 256]) +@pytest.mark.parametrize("topk", [1, 6, 8]) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +@torch.inference_mode() +def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, + dtype: torch.dtype): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + + 
quant_blocksize = 16 + + (_, w1_q, w1_blockscale, + w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights( + e, + n, + k, + in_dtype=dtype, + quant_dtype="nvfp4", + block_shape=None, # use quant_blocksize? + per_act_token_quant=False, + ) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(a, + score, + topk, + renormalize=False) + + a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + + assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q) + + assert w1_gs is not None + assert w2_gs is not None + assert w1_blockscale is not None + assert w2_blockscale is not None + + flashinfer_experts = FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + FlashInferExperts( + a1_gscale=a1_gs, + g1_alphas=(1 / w1_gs), + a2_gscale=a2_gs, + g2_alphas=(1 / w2_gs), + out_dtype=dtype, + quant_dtype="nvfp4", + )) + + flashinfer_output = flashinfer_experts( + hidden_states=a, + w1=w1_q, + w1_scale=w1_blockscale, + w2=w2_q, + w2_scale=w2_blockscale, + a1_scale=a1_gs, + a2_scale=a2_gs, + topk_weights=topk_weights, + topk_ids=topk_ids, + ) + + # Reference check: + a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(a.flatten(), dim=-1)).to(torch.float32) + a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) + _, m_k = a_fp4.shape + a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=a.dtype, + device=a.device, + block_size=quant_blocksize) + + w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype) + w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype) + + for idx in range(0, e): + w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=dtype, + device=w1_q.device, + block_size=quant_blocksize) + w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=dtype, + device=w2_q.device, + 
block_size=quant_blocksize) + + torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) + + torch.testing.assert_close(torch_output, + flashinfer_output, + atol=1e-1, + rtol=1e-1) + + +if __name__ == "__main__": + test_flashinfer_fp4_moe_no_graph((2, 1024, 1024), 40, 1, torch.half) diff --git a/tests/kernels/moe/test_grouped_topk.py b/tests/kernels/moe/test_grouped_topk.py new file mode 100644 index 0000000000000..646e763194fd6 --- /dev/null +++ b/tests/kernels/moe/test_grouped_topk.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the MoE grouped topk kernel + +Run `pytest tests/kernels/moe/test_grouped_topk.py`. +""" +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_grouped_topk, + grouped_topk) +from vllm.platforms import current_platform + + +@pytest.mark.skipif(not current_platform.is_cuda(), + reason="This test is skipped on non-CUDA platform.") +@pytest.mark.parametrize("n_token", [1, 33, 64]) +@pytest.mark.parametrize("n_hidden", [1024, 2048]) +@pytest.mark.parametrize("n_expert", [16]) +@pytest.mark.parametrize("topk", [2]) +@pytest.mark.parametrize("renormalize", [True, False]) +@pytest.mark.parametrize("num_expert_group", [8]) +@pytest.mark.parametrize("topk_group", [2]) +@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"]) +@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5]) +@pytest.mark.parametrize("dtype", + [torch.float16, torch.bfloat16, torch.float32]) +def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int, + n_hidden: int, n_expert: int, topk: int, + renormalize: bool, num_expert_group: int, + topk_group: int, scoring_func: str, + routed_scaling_factor: float, dtype: torch.dtype): + current_platform.seed_everything(0) + hidden_states = torch.randn((n_token, n_hidden), + dtype=dtype, + device="cuda") + gating_output = torch.randn((n_token, n_expert), + dtype=dtype, + 
device="cuda") + e_score_correction_bias = torch.randn((n_expert, ), + dtype=torch.float32, + device="cuda") + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0") + baseline_topk_weights, baseline_topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias) + + test_topk_weights, test_topk_ids = fused_grouped_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias) + + if renormalize: + torch.testing.assert_close(baseline_topk_weights, + test_topk_weights, + atol=2e-2, + rtol=0) + torch.testing.assert_close(baseline_topk_ids, + test_topk_ids, + atol=0, + rtol=0) diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index 6f2869c3a61d7..6112183be5475 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +import textwrap +import traceback from itertools import product from typing import Optional @@ -10,41 +12,52 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import VllmConfig, current_platform, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from 
vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 -from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) -from vllm.model_executor.layers.fused_moe.layer import TritonExperts -from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts) from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from ...utils import multi_gpu_test from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors, reference_moe_impl, run_modular_kernel) from .modular_kernel_tools.mk_objects import ( MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, - MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) + MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info) from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo, parallel_launch_with_config) -# TODO (varun): These requirements are very strict and could be relaxed. -has_all_packages = (has_deep_ep() and has_deep_gemm() and has_pplx()) +has_any_multi_gpu_package = (has_deep_ep() or has_deep_gemm() or has_pplx() + or has_flashinfer_cutlass_fused_moe()) -meets_package_requirements = pytest.mark.skipif( - not has_all_packages, - reason="Requires deep_ep & deep_gemm & pplx packages", +meets_multi_gpu_requirements = pytest.mark.skipif( + not has_any_multi_gpu_package, + reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages", ) +def format_result(verbose, msg, ex=None): + if ex is not None: + x = str(ex) + newx = x.strip(" \n\t")[:16] + if len(newx) < len(x): + newx = newx + " ..." 
+ + prefix = "E\t" + print(f"{textwrap.indent(traceback.format_exc(), prefix)}") + print(f"FAILED {msg} - {newx}\n") + elif verbose: + print(f"PASSED {msg}") + else: + print(".", end="") + + def rank_worker( pgi: ProcessGroupInfo, vllm_config: VllmConfig, cpu_group, config: Config, weights: WeightTensors, + verbose: bool, ): current_platform.seed_everything(pgi.rank) @@ -61,39 +74,64 @@ def rank_worker( TOPKs = config.topks assert isinstance(TOPKs, list) + exceptions = [] + count = 0 + for m, topk in product(Ms, TOPKs): - print(f"Running m={m}, topk={topk} ...") - # override m and topk - cfgx = copy.deepcopy(config) - cfgx.Ms = m - cfgx.topks = topk + try: + print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...") + count = count + 1 + # override m and topk + cfgx = copy.deepcopy(config) + cfgx.Ms = m + cfgx.topks = topk - # inputs for rank - rank_tensors = RankTensors.make(cfgx, pgi) + # inputs for rank + rank_tensors = RankTensors.make(cfgx, pgi) - # modular kernel out - mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, - rank_tensors) + # modular kernel out + mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, + rank_tensors) - with set_current_vllm_config(vllm_config): - ref_out = reference_moe_impl(cfgx, weights, rank_tensors) + with set_current_vllm_config(vllm_config): + ref_out = reference_moe_impl(cfgx, weights, rank_tensors) - torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2) + if config.quant_dtype == "nvfp4": + atol = 1e-1 + rtol = 1e-1 + else: + atol = 3e-2 + rtol = 3e-2 + + torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol) + format_result(verbose, config.describe()) + except Exception as ex: + format_result(verbose, config.describe(), ex) + exceptions.append(ex) + + if len(exceptions) > 0: + raise RuntimeError( + f"{len(exceptions)} of {count} tests failed in child process, " + f"rank={pgi.rank}.") + else: + print(f"{count} of {count} tests passed in child process, " + f"rank={pgi.rank}.") -def 
run(config: Config): +def run(config: Config, verbose: bool): assert config.is_valid() - print(f"Testing config \n{config.describe()} ...") weights: WeightTensors = WeightTensors.make(config) vllm_config, env_dict = config.make_env_data() parallel_launch_with_config(config.world_size, rank_worker, vllm_config, - env_dict, config, weights) + env_dict, config, weights, verbose) Ms = [32, 64] -Ks = [7168] # hidden sizes +# hidden sizes, making this too large will cause fp4 tests to fail. +# Also needs to be a multiple of 1024 for deep_gemm. +Ks = [2048] Ns = [2048] TOPKs = [4, 1] Es = [32] @@ -103,19 +141,16 @@ FUSED_MOE_CHUNK_SIZEs = [None, 16] def is_nyi_config(config: Config) -> bool: # We know these configs to be legitimate. but still fail. + info = expert_info(config.fused_experts_type) - if (config.fused_experts_type in [ - BatchedTritonExperts, BatchedTritonOrDeepGemmExperts, - TritonExperts, TritonOrDeepGemmExperts - ]): + if info.needs_matching_quant: # The triton kernels expect both per-act-token-quant and # per-out-ch-quant or neither. unsupported_quant_config = ((config.is_per_act_token_quant + config.is_per_out_ch_quant) == 1) return unsupported_quant_config - # cutlass kernels dont support expert_maps yet. 
- return config.fused_experts_type == CutlassExpertsFp8 + return not info.supports_expert_map @pytest.mark.parametrize("k", Ks) @@ -128,13 +163,14 @@ def is_nyi_config(config: Config) -> bool: product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("world_size", [2]) -@meets_package_requirements +@multi_gpu_test(num_gpus=2) +@meets_multi_gpu_requirements def test_modular_kernel_combinations_multigpu( k: int, n: int, e: int, dtype: torch.dtype, - quant_config: FusedMoEQuantConfig, + quant_config: Optional[FusedMoEQuantConfig], combination: tuple[mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute], - fused_moe_chunk_size: Optional[int], world_size: int): + fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig): config = Config( Ms=Ms, @@ -149,14 +185,15 @@ def test_modular_kernel_combinations_multigpu( fused_moe_chunk_size=fused_moe_chunk_size, world_size=world_size, ) + if not config.is_valid(): pytest.skip(f"Tests config {config} is not valid. Skipping ...") if is_nyi_config(config): pytest.skip(f"Tests config {config} is nyi. 
Skipping ...") - print(f"{config.describe()}") - run(config) + verbosity = pytestconfig.getoption('verbose') + run(config, verbosity > 0) @pytest.mark.parametrize("k", Ks) @@ -169,13 +206,12 @@ def test_modular_kernel_combinations_multigpu( product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("world_size", [1]) -@meets_package_requirements def test_modular_kernel_combinations_singlegpu( k: int, n: int, e: int, dtype: torch.dtype, - quant_config: FusedMoEQuantConfig, + quant_config: Optional[FusedMoEQuantConfig], combination: tuple[mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute], - fused_moe_chunk_size: Optional[int], world_size: int): + fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig): config = Config( Ms=Ms, K=k, @@ -196,7 +232,8 @@ def test_modular_kernel_combinations_singlegpu( if is_nyi_config(config): pytest.skip(f"Tests config {config} is nyi. 
Skipping ...") - run(config) + verbosity = pytestconfig.getoption('verbose') + run(config, verbosity > 0) if __name__ == '__main__': @@ -211,4 +248,4 @@ if __name__ == '__main__': args = parser.parse_args() config = make_config(args) - run(config) + run(config, True) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 1951eb0c61802..0ea9667914fd5 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -429,11 +429,11 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128], requires_grad=False) - torch.cuda.empty_cache() vllm_moe.experts.w2_weight = Parameter(F.pad( vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128], requires_grad=False) + torch.cuda.synchronize() torch.cuda.empty_cache() # Run forward passes for both MoE blocks diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index 6ca01f9271bba..d71664d94b9c8 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, atol=0, rtol=0) # check mindice - torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0) + # current kernel usage assumes deepgemm requires align_block_size + # when it's not provided then we don't compute m_indices (for cutlass) + if align_block_size is not None: + torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0) + # check permuted_hidden_states, only valid token torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx], permuted_hidden_states[valid_row_idx], diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index 824b072a9f933..7bd1ffce58e96 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -4,15 +4,27 @@ 
import importlib import importlib.metadata from dataclasses import dataclass +from typing import Optional import pytest import torch from packaging import version +from vllm.platforms import current_platform + QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( "quark") is not None and version.parse( importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') +TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda( +) and current_platform.is_device_capability(100) + +if TRTLLM_GEN_MXFP4_AVAILABLE: + from flashinfer import (fp4_quantize, mxfp8_quantize, + next_positive_power_of_2, + reorder_rows_for_gated_act_gemm, shuffle_matrix_a, + shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe) + @dataclass class ModelCase: @@ -54,4 +66,410 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20) - assert output \ No newline at end of file + assert output + + +def swiglu(x, + alpha: float = 1.702, + beta: float = 1.0, + limit: Optional[float] = None): + # Note we add an extra bias of 1 to the linear layer + x_glu, x_linear = torch.chunk(x, 2, dim=-1) + if limit is not None: + x_glu = x_glu.clamp(max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + return out_glu * (x_linear + beta) + + +fp4_lookup_table = [ + 0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6 +] + + +def mxfp4_dequantize(x, scale): + assert x.dtype == torch.uint8 + x = x.view(torch.uint8).to(torch.int32) + x_unpacked = torch.zeros(*x.shape[:-1], + x.shape[-1] * 2, + dtype=torch.int32, + device=x.device) + x_unpacked[..., 0::2].copy_(x & 0xF) + x_unpacked[..., 1::2].copy_((x >> 4) & 0xF) + + x_float = torch.zeros(x_unpacked.shape, + dtype=torch.float32, + device=x.device) + for i, val in enumerate(fp4_lookup_table): + x_float[x_unpacked == i] = val + + scale = scale.view(torch.uint8).to(torch.int32) + scale = (scale << 
23).view(torch.float32) + scale = scale.reshape(*x.shape[:-1], -1) + scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape) + + return x_float * scale + + +def mxfp8_dequantize(x, scale): + assert x.dtype == torch.float8_e4m3fn + x_float = x.to(torch.float32) + + scale = scale.view(torch.uint8).to(torch.int32) + scale = (scale << 23).view(torch.float32) + scale = scale.reshape(*x.shape[:-1], -1) + scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape) + + return x_float * scale + + +def reference_moe( + roouting_logits, + topk, + num_experts, + hidden_states, + w13, + bias13, + w2, + bias2, + alpha, + beta, + limit, + act_type, +): + # renormalize routing + experts = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True) + expert_weights = torch.nn.functional.softmax(experts.values, dim=1) + expert_indices = experts.indices + t = hidden_states.clone() + # MLP #1 + mlp1_weight = w13[expert_indices, ...] + mlp1_bias = bias13[expert_indices, ...] + t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias + t = swiglu(t, alpha=alpha, beta=beta, limit=limit) + + if act_type == 'mxfp8': + t_quantized, t_scale = mxfp8_quantize(t.to(torch.bfloat16), + is_sf_swizzled_layout=False) + t = mxfp8_dequantize(t_quantized, t_scale) + # MLP #2 + mlp2_weight = w2[expert_indices, ...] + mlp2_bias = bias2[expert_indices, ...] + t = torch.einsum("beck,bek->bec", mlp2_weight, t) + mlp2_bias + # Weighted sum of experts + t = torch.einsum("bec,be->bc", t, expert_weights) + assert t.shape == hidden_states.shape + return t.to(torch.bfloat16) + + +def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int): + # Number of tokens in the input tensor. + num_tokens = x.shape[0] + # Factor to account for the imbalance of the experts. + # factor equals to the + # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert + # - 1.0 means perfect expert distribution. + # - > 1.0 means some experts have more + # tokens than the perfect distribution. 
+ # - < 1.0 does not make sense. + imbalance_factor = 1.3 + # Calculate the number of tokens per expert + # assuming perfect distribution. + num_tokens_per_expert = (num_tokens * top_k) // num_experts + # Apply the imbalance factor. + num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. + tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile + # as it's the range supported by the kernel. + tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + return tile_tokens_dim + + +def tg_mxfp4_moe( + router_logits, + topk, + num_experts, + intermediate_size, + hidden_size, + hidden_states, + hidden_states_scale, + w13_weight, + w13_weight_scale, + w13_bias, + w2_weight, + w2_weight_scale, + w2_bias, + act_type, + alpha, + beta, + limit, +) -> torch.Tensor: + sf_block_size = 32 + assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts + and w13_weight.shape[1] == intermediate_size * 2 + and w13_weight.shape[2] == hidden_size // 2) + assert (w13_weight_scale.dim() == 3 + and w13_weight_scale.shape[0] == num_experts + and w13_weight_scale.shape[1] == intermediate_size * 2 + and w13_weight_scale.shape[2] == hidden_size // sf_block_size) + assert (w2_weight.dim() == 3 and w2_weight.shape[0] == num_experts + and w2_weight.shape[1] == hidden_size + and w2_weight.shape[2] == intermediate_size // 2) + assert (w2_weight_scale.dim() == 3 + and w2_weight_scale.shape[1] == hidden_size + and w2_weight_scale.shape[2] == intermediate_size // sf_block_size) + assert (w13_bias.dim() == 2 and w13_bias.shape[0] == num_experts + and w13_bias.shape[1] == intermediate_size * 2) + assert (w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts + and w2_bias.shape[1] == hidden_size) + + # Swap w1 and w3 as the defenition of + # swiglu is different in the trtllm-gen + w13_weight_scale_ = w13_weight_scale.clone() + w13_weight_ = w13_weight.clone() + w13_bias_ = w13_bias.clone() + 
w13_weight[:, :intermediate_size, :].copy_( + w13_weight_[:, intermediate_size:, :]) + w13_weight[:, intermediate_size:, :].copy_( + w13_weight_[:, :intermediate_size, :]) + w13_weight_scale[:, :intermediate_size, :].copy_( + w13_weight_scale_[:, intermediate_size:, :]) + w13_weight_scale[:, intermediate_size:, :].copy_( + w13_weight_scale_[:, :intermediate_size, :]) + w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:]) + w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size]) + + # Interleave the weights and scaling factors for activation + w13_weight_interleaved = [] + w13_weight_scale_interleaved = [] + w13_bias_interleaved = [] + for i in range(num_experts): + w13_weight_interleaved.append( + reorder_rows_for_gated_act_gemm(w13_weight[i].clone())) + w13_weight_scale_interleaved.append( + reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())) + w13_bias_interleaved.append( + reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, + 1))) + w13_weight = torch.stack(w13_weight_interleaved).reshape( + num_experts, 2 * intermediate_size, hidden_size // 2) + w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape( + num_experts, 2 * intermediate_size, hidden_size // 32) + w13_bias = torch.stack(w13_bias_interleaved).reshape( + num_experts, 2 * intermediate_size) + + # Shuffle weights and scaling factors for transposed mma output + gemm1_weights_shuffled = [] + gemm1_scales_shuffled = [] + gemm2_weights_shuffled = [] + gemm2_scales_shuffled = [] + gemm1_bias_shuffled = [] + gemm2_bias_shuffled = [] + epilogue_tile_m = 128 # FIXME: this depends on the kernel internals + for i in range(num_experts): + gemm1_weights_shuffled.append( + shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)) + gemm1_scales_shuffled.append( + shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + + gemm2_weights_shuffled.append( + shuffle_matrix_a(w2_weight[i].view(torch.uint8), 
epilogue_tile_m)) + gemm2_scales_shuffled.append( + shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + gemm1_bias_shuffled.append( + shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)) + gemm2_bias_shuffled.append( + shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)) + + w13_weight = torch.stack(gemm1_weights_shuffled) + w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape( + num_experts, 2 * intermediate_size, + hidden_size // sf_block_size).view(torch.float8_e4m3fn) + w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1) + + w2_weight = torch.stack(gemm2_weights_shuffled) + w2_weight_scale = torch.stack(gemm2_scales_shuffled).reshape( + num_experts, hidden_size, + intermediate_size // sf_block_size).view(torch.float8_e4m3fn) + w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1) + + tg_result = trtllm_fp4_block_scale_moe( + routing_logits=router_logits.to(torch.bfloat16), + routing_bias=None, + hidden_states=hidden_states, + hidden_states_scale=hidden_states_scale, + gemm1_weights=w13_weight, + gemm1_weights_scale=w13_weight_scale, + gemm1_bias=w13_bias, + gemm1_alpha=alpha, + gemm1_beta=beta, + gemm1_clamp_limit=limit, + gemm2_weights=w2_weight, + gemm2_weights_scale=w2_weight_scale, + gemm2_bias=w2_bias, + output1_scale_scalar=None, + output1_scale_gate_scalar=None, + output2_scale_scalar=None, + num_experts=num_experts, + top_k=topk, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts), + routing_method_type=1, # renormalize + do_finalize=True)[0] + return tg_result + + +def check_accuracy(a, b, atol, rtol, percent): + """Allow a mismatch percentage of 1 - percent.""" + if torch.any(torch.isnan(a)): + raise Exception("NaN in reference output") + if torch.any(torch.isnan(b)): + raise 
Exception("NaN in actual output") + if torch.any(torch.isinf(a)): + raise Exception("Inf in reference output") + if torch.any(torch.isinf(b)): + raise Exception("Inf in actual output") + assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}" + + left = torch.abs(a - b) + right = atol + rtol * torch.abs(b) + count = torch.sum(left > right) + mismatch_percent = count / a.numel() + if mismatch_percent > 1 - percent: + raise Exception( + f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} " + f"(threshold: {1-percent:.4f})") + + +@pytest.mark.parametrize("topk", [1, 4]) +@pytest.mark.parametrize("num_experts", [32, 128]) +@pytest.mark.parametrize("num_tokens", [1, 128, 1024]) +@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) +@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), + (1.702, 1.0, 7.0)]) +@pytest.mark.parametrize("act_type", ['mxfp8', 'bf16']) +@pytest.mark.skipif( + not TRTLLM_GEN_MXFP4_AVAILABLE, + reason="nvidia gpu and compute capability sm100 is required for this test") +def test_trtllm_gen_mxfp4_fused_moe( + topk: int, + num_experts: int, + num_tokens: int, + intermediate_size: int, + hidden_size: int, + alpha: float, + beta: float, + limit: Optional[float], + act_type: str, +): + seed = 42 + torch.manual_seed(seed) + hidden_states = torch.randn(num_tokens, + hidden_size, + device="cuda:0", + dtype=torch.bfloat16) + w13 = (torch.randn(num_experts, + intermediate_size * 2, + hidden_size, + device="cuda:0", + dtype=torch.bfloat16)) + w2 = (torch.randn(num_experts, + hidden_size, + intermediate_size, + device="cuda:0", + dtype=torch.bfloat16)) + bias13 = torch.randn(num_experts, intermediate_size * 2, + device="cuda:0") * 10 + bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10 + router_logits = torch.rand(num_tokens, num_experts, + dtype=torch.float32).cuda() + + w13, w13_scale = fp4_quantize(w13, + torch.tensor(1.0, device="cuda:0"), + 32, + sf_use_ue8m0=True, + 
is_sf_swizzled_layout=False) + w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape( + num_experts, intermediate_size * 2, hidden_size // 32) + w2, w2_scale = fp4_quantize(w2, + torch.tensor(1.0, device="cuda:0"), + 32, + sf_use_ue8m0=True, + is_sf_swizzled_layout=False) + w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape( + num_experts, hidden_size, intermediate_size // 32) + if act_type == 'mxfp8': + hidden_states, hidden_states_scale = mxfp8_quantize( + hidden_states, is_sf_swizzled_layout=False) + hidden_states_scale = hidden_states_scale.view( + torch.float8_e4m3fn).reshape(-1) + else: + hidden_states_scale = None + + # reference result + ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16) + w13_ref = mxfp4_dequantize(w13.clone(), w13_scale.clone()) + w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone()) + bias13_ref = bias13 + bias2_ref = bias2 + if act_type == 'mxfp8': + hidden_states_ref = mxfp8_dequantize( + hidden_states, hidden_states_scale).to(torch.float32) + else: + hidden_states_ref = hidden_states.to(torch.float32) + # Process tokens in chunks of 32 to reduce memory usage + chunk_size = 32 + num_chunks = (num_tokens + chunk_size - 1) // chunk_size + for i in range(num_chunks): + start_idx = i * chunk_size + end_idx = min(start_idx + chunk_size, num_tokens) + chunk_result = reference_moe( + router_logits[start_idx:end_idx].to(torch.float32), + topk, + num_experts, + hidden_states_ref[start_idx:end_idx], + w13_ref, + bias13_ref, + w2_ref, + bias2_ref, + alpha, + beta, + limit, + act_type, + ) + ref_result[start_idx:end_idx].copy_(chunk_result) + + # trtllm-gen result + if alpha is not None: + alpha = torch.full((num_experts, ), alpha, device=hidden_states.device) + if limit is not None: + limit = torch.full((num_experts, ), limit, device=hidden_states.device) + if beta is not None: + beta = torch.full((num_experts, ), beta, device=hidden_states.device) + tg_result = tg_mxfp4_moe(router_logits, + topk, + num_experts, + 
intermediate_size, + hidden_size, + hidden_states, + hidden_states_scale, + w13, + w13_scale, + bias13, + w2, + w2_scale, + bias2, + act_type, + alpha=alpha, + beta=beta, + limit=limit) + # relatively loose check since the mxfp4 quantization is less accurate + check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8) diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 3ff385360299b..30388ef9375d4 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -3,6 +3,7 @@ import pytest import torch +from tests.kernels.moe.utils import make_test_weights from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, dequantize_nvfp4_to_dtype) @@ -43,41 +44,20 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 quant_blocksize = 16 - round_up = lambda x, y: (x + y - 1) // y * y - sf_w1_2n = round_up(2 * n, 128) - sf_w1_k = round_up(k // quant_blocksize, 4) - w1_blockscale = torch.empty((e, sf_w1_2n, sf_w1_k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - sf_w2_k = round_up(k, 128) - sf_w2_n = round_up(n // quant_blocksize, 4) - w2_blockscale = torch.empty((e, sf_w2_k, sf_w2_n), - device="cuda", - dtype=torch.float8_e4m3fn) + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1_q = torch.empty((e, 2 * n, k // 2), - device="cuda", - dtype=torch.uint8) - w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8) - w1_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) - w2_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) - - for expert in range(e): - w1_amax = torch.abs(w1).max().to(torch.float32) - w2_amax = 
torch.abs(w2).max().to(torch.float32) - w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax - w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax - - w1_q[expert], w1_blockscale[expert] = ops.scaled_fp4_quant( - w1[expert], w1_gs[expert]) - - w2_q[expert], w2_blockscale[expert] = ops.scaled_fp4_quant( - w2[expert], w2_gs[expert]) + (_, w1_q, w1_blockscale, + w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights( + e, + n, + k, + in_dtype=dtype, + quant_dtype="nvfp4", + block_shape=None, # use quant_blocksize? + per_act_token_quant=False, + ) score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids, _ = fused_topk(a, @@ -88,6 +68,11 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + assert w1_gs is not None + assert w2_gs is not None + assert w1_blockscale is not None + assert w2_blockscale is not None + cutlass_output = cutlass_moe_fp4( a=a, a1_gscale=a1_gs, @@ -104,14 +89,13 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, n=n, k=k, e=e, - device=a.device, ) # Reference check: a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)).to(torch.float32) a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) - _, m_k = a_fp4.shape + a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, a_scale_interleaved, a_global_scale, @@ -126,14 +110,14 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], w1_blockscale[idx], w1_gs[idx], - dtype=w1.dtype, - device=w1.device, + dtype=dtype, + device=w1_q.device, block_size=quant_blocksize) w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], w2_blockscale[idx], w2_gs[idx], - dtype=w2.dtype, - device=w2.device, + dtype=dtype, + device=w2_q.device, block_size=quant_blocksize) torch_output = 
torch_moe(a_in_dtype, w1_d, w2_d, score, topk) diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index e4f4a393dfd56..9e78f4d6e4da0 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -9,13 +9,15 @@ import torch from tests.kernels.utils import torch_experts from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 +from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + CutlassBatchedExpertsFp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) from vllm.platforms import current_platform from vllm.utils import cdiv +from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch try: @@ -75,6 +77,7 @@ def pplx_cutlass_moe( assert torch.cuda.current_device() == pgi.local_rank num_tokens, hidden_dim = a.shape + intermediate_dim = w2.shape[2] num_experts = w1.shape[0] block_size = hidden_dim # TODO support more cases device = pgi.device @@ -123,12 +126,27 @@ def pplx_cutlass_moe( num_local_experts=num_local_experts, num_dispatchers=num_dispatchers) - experts = CutlassExpertsFp8(num_local_experts, - out_dtype, - per_act_token, - per_out_ch, - num_dispatchers=num_dispatchers, - use_batched_format=True) + ab_strides1 = torch.full((num_local_experts, ), + hidden_dim, + device="cuda", + dtype=torch.int64) + ab_strides2 = torch.full((num_local_experts, ), + intermediate_dim, + device="cuda", + dtype=torch.int64) + c_strides1 = torch.full((num_local_experts, ), + 2 * intermediate_dim, + device="cuda", + dtype=torch.int64) + c_strides2 = torch.full((num_local_experts, ), + hidden_dim, + device="cuda", + dtype=torch.int64) + + experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers, + 
out_dtype, per_act_token, per_out_ch, + ab_strides1, ab_strides2, c_strides1, + c_strides2) fused_cutlass_experts = FusedMoEModularKernel( prepare_finalize, @@ -230,6 +248,7 @@ def _pplx_moe( @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) @pytest.mark.parametrize("use_internode", [False]) +@multi_gpu_test(num_gpus=2) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index fbef6706beaf0..3f36d7ada2e94 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -37,6 +37,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.platforms import current_platform from vllm.utils import round_up +from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch requires_pplx = pytest.mark.skipif( @@ -452,6 +453,7 @@ def _pplx_prepare_finalize( @pytest.mark.parametrize("use_internode", [False]) @pytest.mark.optional @requires_pplx +@multi_gpu_test(num_gpus=2) def test_pplx_prepare_finalize_slow( mnk: tuple[int, int, int], e: int, @@ -740,6 +742,7 @@ def _pplx_moe( @pytest.mark.parametrize("use_internode", [False]) @pytest.mark.optional @requires_pplx +@multi_gpu_test(num_gpus=2) def test_pplx_moe_slow( mnk: tuple[int, int, int], e: int, @@ -770,7 +773,7 @@ def test_pplx_moe_slow( a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) - _, w1, w1_s, _, w2, w2_s = make_test_weights( + (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights( e, n, k, @@ -836,7 +839,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, args = dict() if make_weights: - _, w1, w1_s, _, w2, w2_s = make_test_weights( + (_, w1, w1_s, _), (_, w2, w2_s, _) = 
make_test_weights( e, n, k, @@ -880,6 +883,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("use_internode", [False]) @requires_pplx +@multi_gpu_test(num_gpus=2) def test_pplx_prepare_finalize( world_dp_size: tuple[int, int], use_internode: bool, @@ -893,6 +897,7 @@ def test_pplx_prepare_finalize( @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("use_internode", [False]) @requires_pplx +@multi_gpu_test(num_gpus=2) def test_pplx_moe( world_dp_size: tuple[int, int], use_internode: bool, diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index 673a0aa367948..5a0379dfb4475 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -24,7 +24,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): current_platform.seed_everything(seed) # Input tensor of shape (E, T, 2*H) - y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda") + y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") tokens_per_expert = torch.randint( low=0, high=T, @@ -74,7 +74,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): y_se = y_s[e] y_qe = y_q[e] - torch.testing.assert_close(y_se[:nt], ref_s[:nt]) + torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2) torch.testing.assert_close( y_qe[:nt].to(torch.float32), ref_q[:nt].to(torch.float32), diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index c33134981acc0..82960bd57345d 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Optional, Union import torch import vllm._custom_ops as 
ops from tests.kernels.quant_utils import per_block_cast_to_int8 +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX) from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) @@ -169,28 +171,41 @@ def make_quantized_test_activations( def moe_quantize_weights( w: torch.Tensor, w_s: Optional[torch.Tensor], - quant_dtype: Optional[torch.dtype], + quant_dtype: Union[torch.dtype, str, None], per_token_quant: bool, block_shape: Optional[list[int]], -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - assert (quant_dtype == torch.float8_e4m3fn - or quant_dtype == torch.int8), "only fp8/int8 supported" +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + assert (quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8 + or quant_dtype == "nvfp4"), "only fp8/int8/nvfp4 supported" + + w_gs = None if block_shape is not None: assert not per_token_quant if quant_dtype == torch.int8: w, w_s = per_block_cast_to_int8(w, block_shape) - else: + elif quant_dtype == torch.float8_e4m3fn: w, w_s = per_block_cast_to_fp8(w, block_shape) + elif quant_dtype == "nvfp4": + raise RuntimeError("blocked quantization not supported for nvfp4") + else: + raise RuntimeError(f"Unsupported quant type {quant_dtype}") else: if quant_dtype == torch.int8: w, w_s = ops.scaled_int8_quant( w, w_s, use_per_token_if_dynamic=per_token_quant) - else: + elif quant_dtype == torch.float8_e4m3fn: w, w_s = ops.scaled_fp8_quant( w, w_s, use_per_token_if_dynamic=per_token_quant) + elif quant_dtype == "nvfp4": + assert not per_token_quant + w_amax = torch.abs(w).max().to(torch.float32) + w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax + w, w_s = ops.scaled_fp4_quant(w, w_gs) + else: + raise RuntimeError(f"Unsupported quant type {quant_dtype}") - return w, w_s + return w, w_s, w_gs def make_test_weight( @@ 
-198,21 +213,26 @@ def make_test_weight( rows: int, cols: int, in_dtype: torch.dtype = torch.bfloat16, - quant_dtype: Optional[torch.dtype] = None, + quant_dtype: Union[torch.dtype, str, None] = None, block_shape: Optional[list[int]] = None, per_act_token_quant: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor]]: w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15 + w_gs = None if quant_dtype is not None: w_l = [None] * e w_s_l = [None] * e + w_gs_l = [None] * e for idx in range(e): - w_l[idx], w_s_l[idx] = moe_quantize_weights( + w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights( w_16[idx], None, quant_dtype, per_act_token_quant, block_shape) w = torch.stack(w_l) w_s = torch.stack(w_s_l) + if e > 0 and w_gs_l[0] is not None: + w_gs = torch.stack(w_gs_l) if w_s.ndim == 2: assert w_s.shape[-1] == 1 w_s = w_s.view(-1, 1, 1) @@ -225,8 +245,9 @@ def make_test_weight( else: w = w_16 w_s = None + w_gs = None - return w_16, w, w_s + return w_16, w, w_s, w_gs def make_test_weights( @@ -234,14 +255,30 @@ def make_test_weights( n: int, k: int, in_dtype: torch.dtype = torch.bfloat16, - quant_dtype: Optional[torch.dtype] = None, + quant_dtype: Union[torch.dtype, str, None] = None, block_shape: Optional[list[int]] = None, per_act_token_quant: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, - torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor]], + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor]]]: return ( - *make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape, - per_act_token_quant), - *make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, - per_act_token_quant), + make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape, + per_act_token_quant), + 
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, + per_act_token_quant), ) + + +def per_token_cast_to_fp8( + x: torch.Tensor, + block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + pad_size = (block_size - (n % block_size)) % block_size + x = torch.nn.functional.pad(x, + (0, pad_size), value=0) if pad_size > 0 else x + x_view = x.view(m, -1, block_size) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) + return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) diff --git a/tests/kernels/quantization/test_aqlm.py b/tests/kernels/quantization/test_aqlm.py deleted file mode 100644 index 427db3e602921..0000000000000 --- a/tests/kernels/quantization/test_aqlm.py +++ /dev/null @@ -1,40 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -from tests.kernels.utils import opcheck -from vllm import _custom_ops as ops # noqa: F401 - - -def test_aqlm_dequant_opcheck(): - codes = torch.randint(-32768, - 32767, (22016, 512, 1), - device='cuda', - dtype=torch.int16) - codebooks = torch.rand((2, 65536, 1, 8), - device='cuda', - dtype=torch.float16) - codebook_partition_sizes = [11008, 11008] - - opcheck(torch.ops._C.aqlm_dequant, - (codes, codebooks, codebook_partition_sizes)) - - -def test_aqlm_gemm_opcheck(): - input = torch.rand((4, 4096), device='cuda', dtype=torch.float16) - codes = torch.randint(-32768, - 32767, (12288, 512, 1), - device='cuda', - dtype=torch.int16) - codebooks = torch.rand((3, 65536, 1, 8), - device='cuda', - dtype=torch.float16) - scales = torch.rand((12288, 1, 1, 1), device='cuda', dtype=torch.float16) - codebook_partition_sizes = [4096, 4096, 4096] - bias = None - - opcheck(torch.ops._C.aqlm_gemm, - (input, codes, codebooks, scales, codebook_partition_sizes, None)) - opcheck(torch.ops._C.aqlm_gemm, - 
(input, codes, codebooks, scales, codebook_partition_sizes, bias)) diff --git a/tests/kernels/quantization/test_awq_triton.py b/tests/kernels/quantization/test_awq_triton.py index 96797e85bd125..9354495642b28 100644 --- a/tests/kernels/quantization/test_awq_triton.py +++ b/tests/kernels/quantization/test_awq_triton.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for the AWQ Triton kernel. -Run `pytest tests/kernels/test_awq_triton.py`. +Run `pytest tests/kernels/quantization/test_awq_triton.py`. """ import pytest import torch diff --git a/tests/kernels/quantization/test_cutlass_2of4_sparse.py b/tests/kernels/quantization/test_cutlass_2of4_sparse.py index 878f66647e19e..ae61b3b3a28a8 100644 --- a/tests/kernels/quantization/test_cutlass_2of4_sparse.py +++ b/tests/kernels/quantization/test_cutlass_2of4_sparse.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for sparse cutlass kernels -Run `pytest tests/kernels/test_semi_structured.py`. +Run `pytest tests/kernels/quantization/test_cutlass_2of4_sparse.py`. """ import pytest diff --git a/tests/kernels/quantization/test_cutlass_scaled_mm.py b/tests/kernels/quantization/test_cutlass_scaled_mm.py index 8730eeaaa761c..65320509e173f 100644 --- a/tests/kernels/quantization/test_cutlass_scaled_mm.py +++ b/tests/kernels/quantization/test_cutlass_scaled_mm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for cutlass kernels -Run `pytest tests/kernels/test_cutlass.py`. +Run `pytest tests/kernels/quantization/test_cutlass_scaled_mm.py`. 
""" import random @@ -535,7 +535,7 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, expert_offsets = torch.zeros((num_experts + 1), device=device, - dtype=torch.int32) + dtype=torch.int64) problem_sizes = torch.zeros((num_experts, 3), device=device, diff --git a/tests/kernels/quantization/test_cutlass_w4a8.py b/tests/kernels/quantization/test_cutlass_w4a8.py new file mode 100644 index 0000000000000..f659408efe8c6 --- /dev/null +++ b/tests/kernels/quantization/test_cutlass_w4a8.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the CUTLASS W4A8 kernel. + +Run `pytest tests/kernels/quantization/test_cutlass_w4a8.py`. +""" + +from dataclasses import dataclass +from typing import Optional + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_rows, quantize_weights) +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. 
+IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 + +MNK_SHAPES = [(1, 128, 128), (1, 512, 1024), (1, 4096, 4096), (1, 8192, 28672), + (13, 8192, 4096), (26, 4096, 8192), (64, 4096, 4096), + (64, 8192, 28672), (257, 128, 4096), (257, 4096, 4096), + (1024, 4096, 8192), (1024, 8192, 4096)] + +# TODO(czhu): get supported schedules from fn +SCHEDULES = [ + '128x16_1x1x1', '256x16_1x1x1', '128x32_1x1x1', '256x32_1x1x1', + '128x64_1x1x1', '256x64_1x1x1', '128x128_1x1x1', '256x128_1x1x1', + '128x256_1x1x1', '128x256_2x1x1' +] + + +@dataclass +class TypeConfig: + act_type: torch.dtype + weight_type: ScalarType + output_type: Optional[torch.dtype] + group_scale_type: Optional[torch.dtype] + channel_scale_type: Optional[torch.dtype] + token_scale_type: Optional[torch.dtype] + + +@dataclass +class Tensors: + w_ref: torch.Tensor + a_ref: torch.Tensor + a: torch.Tensor + w_q: torch.Tensor + w_g_s: torch.Tensor + w_ch_s: torch.Tensor + w_tok_s: torch.Tensor + + +# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints, +# Ch Scales Type, Tok Scales Type) +TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype], + Optional[torch.dtype], bool] +TEST_TYPES = [ + *( + TypeConfig(act_type=torch.float8_e4m3fn, + weight_type=w_type, + output_type=o_type, + group_scale_type=torch.float8_e4m3fn, + channel_scale_type=torch.float32, + token_scale_type=torch.float32) + for w_type in [scalar_types.int4] + # TODO(czhu): fp16 out type + for o_type in [torch.bfloat16]), +] + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantization methods can have +# multiple kernels and some kernels support multiple quantization methods. 
+IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) + + +# For testing quantized linear kernels +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return tensor.clamp(min=finfo.min, + max=finfo.max).to(dtype=torch.float8_e4m3fn) + + +def cutlass_quantize_and_pack(atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights(w, + wtype, + group_size=group_size, + zero_points=zero_points) + + # since scales are cast to fp8, we need to compute w_ref this way + w_ref = ((w_q).to(torch.float32) * w_s.to(atype).to( + torch.float32).repeat_interleave(group_size, dim=0)).to(atype) + + # bit mask prevents sign extending int4 when packing + w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape) + w_q = w_q.t().contiguous().t() # convert to col major + + w_q_packed = ops.cutlass_encode_and_reorder_int4b(w_q) + w_s_packed = ops.cutlass_pack_scale_fp8(w_s.to(atype)) + + return w_ref, w_q_packed, w_s_packed, w_zp + + +def create_test_tensors(shape: tuple[int, int, int], types: TypeConfig, + group_size: Optional[int]) -> Tensors: + m, n, k = shape + + print("create_test_tensors, shape:", shape, "types:", types, "group_size:", + group_size) + + a = to_fp8(torch.randn((m, k), device="cuda")) + w = to_fp8(torch.randn((k, n), device="cuda")) + + if types.group_scale_type is not None: + w = w.to(types.group_scale_type) + if w.dtype.itemsize == 1: + w = w.to(torch.float16) + + w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack( + a.dtype, w, types.weight_type, types.group_scale_type, group_size, + False) + + a_ref = a.to(torch.float32) + w_ref = w_ref.to(torch.float32) + + # for the practical use case we need per-tok scales for fp8 activations + w_tok_s = torch.randn((m, ), device='cuda', dtype=types.token_scale_type) + # weights are 
already per-group quantized, use placeholder here + w_ch_s = torch.ones((n, ), device='cuda', dtype=types.channel_scale_type) + + return Tensors(w_ref=w_ref, + a_ref=a_ref, + a=a, + w_q=w_q_packed, + w_g_s=w_s, + w_ch_s=w_ch_s, + w_tok_s=w_tok_s) + + +def mm_test_helper(types: TypeConfig, + tensors: Tensors, + group_size: Optional[int] = None, + schedule: Optional[str] = None): + # CUTLASS upstream uses fp8 with fastaccum as reference + # https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406 + output_ref = torch._scaled_mm( + tensors.a_ref.to(types.act_type), + tensors.w_ref.to(types.act_type).t().contiguous().t(), # col major + tensors.w_tok_s.unsqueeze(1), + tensors.w_ch_s.unsqueeze(0), + out_dtype=types.output_type, + use_fast_accum=True) + + output = ops.cutlass_w4a8_mm( + a=tensors.a, + b_q=tensors.w_q, + b_group_scales=tensors.w_g_s, + b_group_size=group_size, + b_channel_scales=tensors.w_ch_s, + a_token_scales=tensors.w_tok_s, + ) + + print(output) + print(output_ref) + + torch.testing.assert_close(output, + output_ref.to(output.dtype), + rtol=1e-3, + atol=1e-3) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="CUTLASS W4A8 is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("types", TEST_TYPES) +@pytest.mark.parametrize("schedule", SCHEDULES) +def test_cutlass_w4a8(shape, types: TypeConfig, schedule): + group_sizes = [128] + for group_size in group_sizes: + tensors = create_test_tensors(shape, types, group_size) + mm_test_helper(types, tensors, group_size, schedule) + + +# Test to make sure cuda graphs work +class W4A8Layer(torch.nn.Module): + + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + + def forward(self, a): + return ops.cutlass_w4a8_mm(a=a, **self.kwargs) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="CUTLASS W4A8 is not supported on this 
GPU type.") +def test_w4a8_cuda_graph(): + m, n, k = 512, 4096, 4096 + + a = to_fp8(torch.randn((m, k), device="cuda")) + b = to_fp8(torch.randn((k, n), device="cuda")) + + wtype = scalar_types.int4 + stype = torch.float8_e4m3fn + group_size = 128 + zero_points = False + + w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack( + a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points) + + w_tok_s = torch.randn((m, ), device='cuda', dtype=torch.float32) + w_ch_s = torch.ones((n, ), device='cuda', dtype=torch.float32) + + # Construct a trivial model with a single layer that calls the kernel + model = W4A8Layer( + b_q=w_q_packed, + b_group_scales=w_s, + b_group_size=group_size, + b_channel_scales=w_ch_s, + a_token_scales=w_tok_s, + ) + + output_ref = torch._scaled_mm( + a, + w_ref.to(a.dtype).t().contiguous().t(), # col major + w_tok_s.unsqueeze(1), + w_ch_s.unsqueeze(0), + out_dtype=torch.bfloat16, + use_fast_accum=True) + + # Run the model with a cuda graph + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output = model(a) + + output.zero_() + g.replay() + + torch.testing.assert_close(output, output_ref, rtol=1e-3, atol=1e-3) diff --git a/tests/kernels/quantization/test_flashinfer_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_scaled_mm.py new file mode 100644 index 0000000000000..9f669c6df8bd5 --- /dev/null +++ b/tests/kernels/quantization/test_flashinfer_scaled_mm.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm + +if not current_platform.has_device_capability(100): + pytest.skip( + reason= + "Flashinfer FP8 gemms requires compute capability of 10.0 or above.", + allow_module_level=True, + ) + +DTYPES = [torch.float16, 
torch.bfloat16] +# m, n, k +SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)] +PAD_SHAPES = [(150, 128, 64), (128, 128, 96)] +SHAPES.extend(PAD_SHAPES) + +SEEDS = [42] +CUDA_DEVICES = ["cuda:0"] + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", SHAPES) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("autotune", [False, True]) +@torch.inference_mode() +def test_flashinfer_fp8_gemm( + dtype: torch.dtype, + shape: tuple[int, int, int], + use_bias: bool, + seed: int, + device: str, + autotune: bool, +) -> None: + current_platform.seed_everything(seed) + m, n, k = shape + a = torch.randn((m, k), dtype=dtype, device=device) + b = torch.randn((n, k), dtype=dtype, device=device) / k + + a_fp8, a_scale = ops.scaled_fp8_quant(a) + b_fp8, b_scale = ops.scaled_fp8_quant(b) + + expected_out = torch.mm( + a_scale * a_fp8.to(dtype=torch.float32), + b_scale * b_fp8.to(dtype=torch.float32).t(), + ).to(dtype=dtype) + + if use_bias: + bias = torch.randn((n, ), dtype=dtype, device=device) + expected_out = expected_out + bias + else: + bias = None + + import flashinfer + + with flashinfer.autotune(autotune): + out = flashinfer_scaled_fp8_mm( + a_fp8, + b_fp8.t(), + a_scale, + b_scale, + dtype, + bias=bias, + ) + + torch.testing.assert_close(out, expected_out, atol=1e-2, rtol=1e-2) diff --git a/tests/kernels/quantization/test_machete_mm.py b/tests/kernels/quantization/test_machete_mm.py index a842d2f1cbe8d..50584f3f82d4c 100644 --- a/tests/kernels/quantization/test_machete_mm.py +++ b/tests/kernels/quantization/test_machete_mm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for the machete kernel. -Run `pytest tests/kernels/test_machete_mm.py`. +Run `pytest tests/kernels/quantization/test_machete_mm.py`. 
""" import math @@ -95,23 +95,23 @@ TEST_TYPES = [ token_scale_type=None) for w_type in [scalar_types.uint4, scalar_types.uint8] for a_type in [torch.float16, torch.bfloat16]), - # QQQ style - *(TypeConfig(act_type=torch.int8, - weight_type=scalar_types.uint4b8, - output_type=torch.float16, - group_scale_type=group_scale_type, - group_zero_type=None, - channel_scale_type=torch.float, - token_scale_type=torch.float) - for group_scale_type in [None, torch.float16]), - *(TypeConfig(act_type=torch.float8_e4m3fn, - weight_type=scalar_types.uint4b8, - output_type=torch.float16, - group_scale_type=group_scale_type, - group_zero_type=None, - channel_scale_type=torch.float, - token_scale_type=torch.float) - for group_scale_type in [None, torch.float16]), + # # QQQ style + # *(TypeConfig(act_type=torch.int8, + # weight_type=scalar_types.uint4b8, + # output_type=torch.float16, + # group_scale_type=group_scale_type, + # group_zero_type=None, + # channel_scale_type=torch.float, + # token_scale_type=torch.float) + # for group_scale_type in [None, torch.float16]), + # *(TypeConfig(act_type=torch.float8_e4m3fn, + # weight_type=scalar_types.uint4b8, + # output_type=torch.float16, + # group_scale_type=group_scale_type, + # group_zero_type=None, + # channel_scale_type=torch.float, + # token_scale_type=torch.float) + # for group_scale_type in [None, torch.float16]), ] # TODO: in future PR refactor this and `is_quant_method_supported` in the kernel diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index cea7700ac3293..0be020085bfa4 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for the marlin kernel. -Run `pytest tests/kernels/marlin/test_marlin_gemm.py`. +Run `pytest tests/kernels/quantization/test_marlin_gemm.py`. 
""" import pytest import torch @@ -13,11 +13,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) -from vllm.model_executor.layers.quantization.qqq import ( - MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N, - MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, query_marlin_supported_quant_types) @@ -31,8 +27,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( marlin_weights) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( marlin_24_quantize) -from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import ( # noqa: E501 - marlin_qqq_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) from vllm.scalar_type import scalar_types @@ -449,68 +443,6 @@ def test_hqq_marlin_gemm( assert max_diff < 0.04 -@pytest.mark.skipif(not is_quant_method_supported("qqq"), - reason="Marlin is not supported on this GPU type.") -@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) -@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS) -@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES) -@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_marlin_qqq_gemm( - k_chunk, - n_chunk, - num_bits, - group_size, - mnk_factors, -): - int8_traits = torch.iinfo(torch.int8) - m_factor, n_factor, k_factor = mnk_factors - - size_m = m_factor - size_k = k_chunk * k_factor - 
size_n = n_chunk * n_factor - - a_input = rand_data((size_m, size_k)) - b_weight = rand_data((size_k, size_n)) - - # Quantize activations - s_a = a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to( - torch.float) - q_a = (a_input / s_a).round().clamp(int8_traits.min, - int8_traits.max).to(torch.int8) - - # Quantize weights - w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = \ - marlin_qqq_quantize(b_weight, num_bits, group_size) - - workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N, - MARLIN_QQQ_MAX_PARALLEL) - - opcheck(torch.ops._C.marlin_qqq_gemm, - (q_a, marlin_qqq_q_w, s_a, marlin_qqq_s_channel, - marlin_qqq_s_group, workspace.scratch, a_input.shape[0], - b_weight.shape[1], a_input.shape[1])) - - output = ops.marlin_qqq_gemm( - q_a, - marlin_qqq_q_w, - s_a, - marlin_qqq_s_channel, - marlin_qqq_s_group, - workspace.scratch, - a_input.shape[0], - b_weight.shape[1], - a_input.shape[1], - ) - output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref) - - torch.cuda.synchronize() - - max_diff = compute_max_diff(output, output_ref) - - assert max_diff < 0.04 - - def test_marlin_gemm_subset_input(): quant_type = scalar_types.uint4b8 group_size = 128 @@ -602,18 +534,3 @@ def test_marlin_gemm_with_bias(size_m): max_diff = compute_max_diff(output, output_ref) assert max_diff < 0.04 - - -def test_marlin_gemm_opcheck(): - size_m = 2048 - size_n = 4096 - size_k = 4096 - a = torch.rand((size_m, size_n), device='cuda', dtype=torch.float16) - w = torch.randint(-5, 5, (256, 8192), device='cuda', dtype=torch.int32) - s = torch.full((32, size_k), 0.125, device='cuda', dtype=torch.float16) - wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL).scratch - x = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k) - y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k) - torch.testing.assert_close(x, y) - opcheck(torch.ops._C.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k)) diff --git 
a/tests/kernels/quantization/test_triton_scaled_mm.py b/tests/kernels/quantization/test_triton_scaled_mm.py index 24245663fb1d6..d8cfb5710dbad 100644 --- a/tests/kernels/quantization/test_triton_scaled_mm.py +++ b/tests/kernels/quantization/test_triton_scaled_mm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for the triton_scaled_mm kernel -Run `pytest tests/kernels/test_triton_scaled_mm.py`. +Run `pytest tests/kernels/quantization/test_triton_scaled_mm.py`. """ import importlib from typing import Optional diff --git a/tests/kernels/test_flex_attention.py b/tests/kernels/test_flex_attention.py index f76bd192460c9..39753c0cc15b9 100644 --- a/tests/kernels/test_flex_attention.py +++ b/tests/kernels/test_flex_attention.py @@ -9,12 +9,17 @@ import pytest import torch from packaging import version -from vllm import SamplingParams +from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config) +from vllm.v1.attention.backends.flex_attention import ( + FlexAttentionMetadataBuilder) -from ..models.utils import check_embeddings_close +from ..models.utils import check_embeddings_close, check_logprobs_close TORCH_VERSION = version.parse(torch.__version__) MINIMUM_TORCH_VERSION = version.parse("2.7.0") +DIRECT_BUILD_VERSION = version.parse("2.9.dev0") def set_seed(seed): @@ -34,22 +39,18 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): """Test that FlexAttention produces the same outputs as the default backend. This test compares the outputs from the FlexAttention backend with - the default backend, ensuring they are identical when using the same seed. + the default backend, ensuring they are similar when using the same seed. 
""" model_name = "Qwen/Qwen2.5-1.5B-Instruct" seed = 42 max_tokens = 24 + num_logprobs = 5 prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", ] - sampling_params = SamplingParams(temperature=0.0, - top_p=1.0, - seed=seed, - max_tokens=max_tokens) - # Run with flex attention with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -61,7 +62,8 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): tensor_parallel_size=1, num_gpu_blocks_override=128, enforce_eager=True) as llm_flex: - output_flex = llm_flex.generate(prompts, sampling_params) + output_flex = llm_flex.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs) # Run with default backend with monkeypatch.context() as m: @@ -71,20 +73,17 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): runner="generate", tensor_parallel_size=1, num_gpu_blocks_override=128, - enforce_eager=True) as llm_default: - output_default = llm_default.generate(prompts, sampling_params) + enforce_eager=True, + gpu_memory_utilization=0.85) as llm_default: + output_default = llm_default.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs) - # Compare outputs from both backends - for i, (flex_result, - default_result) in enumerate(zip(output_flex, output_default)): - prompt = prompts[i] - flex_text = flex_result[1][0] - default_text = default_result[1][0] - - assert flex_text == default_text, ( - f"FlexAttention output doesn't match default for: {prompt!r}\n" - f"FlexAttention: {flex_text!r}\n" - f"Default: {default_text!r}") + check_logprobs_close( + outputs_0_lst=output_flex, + outputs_1_lst=output_default, + name_0="flex", + name_1="default", + ) @pytest.mark.skipif( @@ -136,5 +135,70 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch): ) +@pytest.mark.skipif( + not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION, + reason="CUDA not available or PyTorch version < 2.7", +) +def 
test_block_mask_direct_vs_slow_path(): + """Test that direct path block mask is a superset of slow path. + + The direct path may include extra blocks for performance (over-estimation), + but must include all blocks that the slow path determines are necessary. + """ + device = torch.device("cuda") + + vllm_config = create_vllm_config(model_name="meta-llama/Meta-Llama-3-8B", + block_size=16, + max_model_len=1024) + kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + + # Use a mixed batch that will create groups spanning multiple sequences + batch_spec = BatchSpec(seq_lens=[35, 64, 128, 256], + query_lens=[33, 5, 32, 64], + name="test_mixed_batch") + + common_attn_metadata = create_common_attn_metadata( + batch_spec, vllm_config.cache_config.block_size, device) + + builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config, + device) + + metadata_direct = builder.build(common_prefix_len=0, + common_attn_metadata=common_attn_metadata) + builder.direct_build = False + metadata_slow = builder.build(common_prefix_len=0, + common_attn_metadata=common_attn_metadata) + + assert metadata_direct.block_mask is not None + assert metadata_slow.block_mask is not None + + # Extract block indices for comparison, B, H are the same + direct_indices = metadata_direct.block_mask.kv_indices[0, 0] + slow_indices = metadata_slow.block_mask.kv_indices[0, 0] + direct_num = metadata_direct.block_mask.kv_num_blocks[0, 0] + slow_num = metadata_slow.block_mask.kv_num_blocks[0, 0] + + # main test: every block needed by slow path must be in direct path + num_groups = direct_num.shape[0] + all_contained = True + missing_details = [] + + for group_idx in range(num_groups): + direct_blocks = set( + direct_indices[group_idx, :direct_num[group_idx]].tolist()) + slow_blocks = set( + slow_indices[group_idx, :slow_num[group_idx]].tolist()) + + missing_blocks = slow_blocks - direct_blocks + if missing_blocks: + all_contained = False + missing_details.append( + f"Group {group_idx}: 
missing {sorted(missing_blocks)}") + + assert all_contained, ( + "Direct path is missing blocks required by slow path:\n" + + "\n".join(missing_details)) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/kernels/test_onednn.py b/tests/kernels/test_onednn.py new file mode 100644 index 0000000000000..17692384ac9a9 --- /dev/null +++ b/tests/kernels/test_onednn.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the oneDNN int8 scaled-mm kernels (CPU only).""" + +from typing import Optional + +import pytest +import torch + +from tests.kernels.utils import to_int8 +from vllm import _custom_ops as ops +from vllm.platforms import current_platform + +if not current_platform.is_cpu(): + pytest.skip("skipping CPU-only tests", allow_module_level=True) + +NK_FACTORS = [ + (256, 128), + (4096, 4096), + (16384, 4096), + (1023, 491), + (1001, 15), +] +M_FACTORS = [ + (16, 1, 32, 128, 64), + (1, 17, 1, 31, 17), +] +CACHE_SIZES = [2] +DTYPE = [torch.bfloat16] + + +def rand_int8(shape: tuple, device: str = "cpu"): + return to_int8(torch.rand(shape, device=device) * 255 - 128) + + +def ref_int8_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + azp: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + output_type: torch.dtype, +): + if azp is not None: + a = a.to(dtype=torch.float32) - azp.to(dtype=torch.float32) + output = torch.mm((scale_a * a.to(dtype=torch.float32)), + (scale_b * b.to(dtype=torch.float32))) + if bias is not None: + output += bias.float() + + return output.to(dtype=output_type) + + +def onednn_int8_gemm_test_helper(primitive_cache_size: int, + m: int, + n: int, + k: int, + per_tensor_a_quant: bool, + per_tensor_b_quant: bool, + use_azp: bool, + use_bias: bool, + out_dtype: torch.dtype = torch.bfloat16, + device: str = "cpu"): + # Test for a oneDNN kernel with per-tensor / per-token 
activation + # quantization and per-tensor / per-output channel weight quantization. + a = to_int8(torch.randn((m, k), device=device) * 5) + b = to_int8(torch.randn((n, k), device=device).t() * 5) + + a_scales_shape = (1, 1) if per_tensor_a_quant else (m, 1) + b_scales_shape = (1, 1) if per_tensor_b_quant else (1, n) + + scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32)) + scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32)) + + if use_azp: + azp = torch.rand(a_scales_shape, dtype=torch.float32) * 10 + 1.5 + azp = (azp / scale_a).round().to(dtype=torch.int32) + azp_adj = scale_b * b.sum(dim=0, keepdim=True, dtype=torch.float32) + else: + azp = None + azp_adj = None + + if use_bias: + bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10 + else: + bias = None + + handler = ops.create_onednn_scaled_mm( + b, + scale_b, + out_dtype, + not per_tensor_a_quant, + use_azp, + primitive_cache_size, + ) + + out = torch.zeros((m, n), dtype=out_dtype) + ops.onednn_scaled_mm(handler, a, out, scale_a, azp, azp_adj, bias) + baseline = ref_int8_scaled_mm(a, b, scale_a, scale_b, azp, bias, out_dtype) + + torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) + + if use_bias: + # To test runtime bias setting + out = torch.zeros((m, n), dtype=out_dtype) + ops.onednn_scaled_mm(handler, a, out, scale_a, azp, azp_adj, None) + baseline = ref_int8_scaled_mm(a, b, scale_a, scale_b, azp, None, + out_dtype) + + torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) + + +@pytest.mark.parametrize("n,k", NK_FACTORS) +@pytest.mark.parametrize("m_list", M_FACTORS) +@pytest.mark.parametrize("per_tensor_a_scale", [True, False]) +@pytest.mark.parametrize("per_tensor_b_scale", [True, False]) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("use_azp", [True, False]) +@pytest.mark.parametrize("output_type", DTYPE) +@pytest.mark.parametrize("primitive_cache_size", CACHE_SIZES) +def 
test_onednn_int8_scaled_gemm( + n: int, + k: int, + m_list: tuple[int], + per_tensor_a_scale: bool, + per_tensor_b_scale: bool, + use_bias: bool, + use_azp: bool, + output_type: torch.dtype, + primitive_cache_size: int, +): + for m in m_list: + onednn_int8_gemm_test_helper( + primitive_cache_size=primitive_cache_size, + m=m, + n=n, + k=k, + per_tensor_a_quant=per_tensor_a_scale, + per_tensor_b_quant=per_tensor_b_scale, + use_bias=use_bias, + use_azp=use_azp, + out_dtype=output_type, + ) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 909b73933139d..3475993ff8f07 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -3,15 +3,13 @@ import tempfile from collections import OrderedDict -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest import torch import torch.nn as nn from huggingface_hub import snapshot_download -import vllm -from vllm.config import LoRAConfig from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, initialize_model_parallel) @@ -21,7 +19,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.model_loader import get_model from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.platforms import current_platform @@ -104,6 +101,7 @@ def dummy_model() -> nn.Module: ])) model.config = MagicMock() model.embedding_modules = {"lm_head": "lm_head"} + model.unpadded_vocab_size = 32000 return model @@ -137,6 +135,8 @@ def dummy_model_gate_up() -> nn.Module: ], } model.embedding_modules = {"lm_head": "lm_head"} + model.unpadded_vocab_size = 32000 + return model @@ -216,34 +216,6 @@ def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") 
-@pytest.fixture(scope="session") -def phi2_lora_files(): - return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora") - - -@pytest.fixture -def llama_2_7b_engine_extra_embeddings(): - cleanup_dist_env_and_memory(shutdown_ray=True) - get_model_old = get_model - - def get_model_patched(**kwargs): - kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4, - max_lora_rank=8) - return get_model_old(**kwargs) - - with patch("vllm.worker.model_runner.get_model", get_model_patched): - engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) - yield engine.llm_engine - del engine - cleanup_dist_env_and_memory(shutdown_ray=True) - - -@pytest.fixture -def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings): - yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker. - model_runner.model) - - @pytest.fixture def reset_default_device(): """ diff --git a/tests/lora/test_add_lora.py b/tests/lora/test_add_lora.py index d7b019509fa3e..44755c603f281 100644 --- a/tests/lora/test_add_lora.py +++ b/tests/lora/test_add_lora.py @@ -5,7 +5,6 @@ import time import pytest -import vllm.envs as env from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args) @@ -98,12 +97,10 @@ async def test_add_lora(chatglm3_lora_files): # Run with warmup add_lora_tasks = [llm.add_lora(lr) for lr in warmup_run_requests] add_lora_results = await asyncio.gather(*add_lora_tasks) - if env.VLLM_USE_V1: - # Test that all all_lora calls are successful. - assert all(add_lora_results) - else: - # No way to check V0 engine results as the calls just return None. - pass + + # Test that all add_lora calls are successful. 
+ assert all(add_lora_results) + time_with_add_lora = await requests_processing_time( llm, warmup_run_requests) diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py deleted file mode 100644 index 774ebb9db2106..0000000000000 --- a/tests/lora/test_baichuan.py +++ /dev/null @@ -1,112 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -import vllm -from vllm.distributed import cleanup_dist_env_and_memory -from vllm.lora.request import LoRARequest - -MODEL_PATH = "baichuan-inc/Baichuan-7B" - -PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 - - -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: - prompts = [ - PROMPT_TEMPLATE.format(query="How many singers do we have?"), - PROMPT_TEMPLATE.format( - query= - "What is the average, minimum, and maximum age of all singers from France?" 
# noqa: E501 - ), - PROMPT_TEMPLATE.format( - query= - "Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501 - ), - ] - print(prompts) - sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256) - outputs = llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) - # Print the outputs. - generated_texts: list[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -def test_baichuan_lora(baichuan_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - trust_remote_code=True) - - expected_lora_output = [ - "SELECT count(*) FROM singer", - "SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'", # noqa: E501 - "SELECT name , country , age FROM singer ORDER BY age ASC", - ] - - output1 = do_sample(llm, baichuan_lora_files, lora_id=1) - for i in range(len(expected_lora_output)): - assert output1[i] == expected_lora_output[i] - output2 = do_sample(llm, baichuan_lora_files, lora_id=2) - for i in range(len(expected_lora_output)): - assert output2[i] == expected_lora_output[i] - - -@pytest.mark.parametrize("fully_sharded", [True, False]) -def test_baichuan_tensor_parallel_equality(baichuan_lora_files, - num_gpus_available, fully_sharded): - if num_gpus_available < 4: - pytest.skip(f"Not enough GPUs for tensor parallelism {4}") - - llm_tp1 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) - output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1) - - del llm_tp1 - cleanup_dist_env_and_memory() - - llm_tp2 = vllm.LLM(MODEL_PATH, - enable_lora=True, - 
max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=2, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) - output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2) - - del llm_tp2 - cleanup_dist_env_and_memory() - - assert output_tp1 == output_tp2 - - llm_tp4 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) - output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2) - - del llm_tp4 - cleanup_dist_env_and_memory() - - assert output_tp1 == output_tp4 diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 92db023babc28..6e2dda464d8eb 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -243,7 +243,7 @@ def check_punica_wrapper(punica_wrapper) -> bool: @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) @@ -347,7 +347,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: @torch.inference_mode() # @pytest.mark.skip( # reason="Fails when loras are in any slot other than the first.") -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) @@ -486,7 +486,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) @pytest.mark.parametrize("stage", 
STAGES) @@ -620,12 +620,15 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -@pytest.mark.parametrize("bias_enabled", [True, False]) -def test_linear_replicated(dist_init, num_loras, device, stage, - bias_enabled) -> None: +def test_linear_replicated( + dist_init, + num_loras, + device, + stage, +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -634,10 +637,11 @@ def test_linear_replicated(dist_init, num_loras, device, stage, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16, + ) def create_random_linear_replicated_layer(): @@ -651,10 +655,6 @@ def test_linear_replicated(dist_init, num_loras, device, stage, lora_linear.create_lora_weights(max_loras, lora_config) assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == 1) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): @@ -734,14 +734,13 @@ def test_linear_replicated(dist_init, num_loras, device, stage, @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("orientation", ["row", "column"]) @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) 
-@pytest.mark.parametrize("bias_enabled", [True, False]) def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, - device, stage, bias_enabled) -> None: + device, stage) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -750,11 +749,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16, + ) def create_random_linear_parallel_layer(): if orientation == "row": @@ -777,10 +777,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, lora_linear.create_lora_weights(max_loras, lora_config) assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == 1) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None + return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): @@ -860,14 +857,13 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -@pytest.mark.parametrize("bias_enabled", [True, False]) def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - device, stage, bias_enabled) -> None: + device, stage) -> None: if current_platform.is_cuda_alike(): 
torch.cuda.set_device(device) @@ -876,11 +872,12 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16, + ) def create_column_parallel_packed_layer(): if repeats == 2: @@ -924,10 +921,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, model_config=FakeConfig()) assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == n_slices) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None + return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index b1ad1fdd06064..06196cc697cec 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -113,8 +113,7 @@ def test_llama_lora(sql_lora_files): enable_lora=True, # also test odd max_num_seqs max_num_seqs=13, - max_loras=4, - enable_chunked_prefill=True) + max_loras=4) generate_and_test(llm, sql_lora_files) @@ -128,7 +127,6 @@ def test_llama_lora_tp4(sql_lora_files): max_num_seqs=16, max_loras=4, tensor_parallel_size=4, - enable_chunked_prefill=True, ) generate_and_test(llm, sql_lora_files) @@ -144,7 +142,6 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): max_loras=4, tensor_parallel_size=4, fully_sharded_loras=True, - enable_chunked_prefill=True, ) generate_and_test(llm, sql_lora_files) diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 8f8a27006cf67..c9ab32edc7f32 
100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -21,6 +21,8 @@ from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, WorkerLoRAManager) from vllm.platforms import current_platform +from .utils import create_peft_lora + EMBEDDING_MODULES = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", @@ -35,17 +37,6 @@ DEVICES = ([ DEFAULT_DTYPE = torch.get_default_dtype() -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch: pytest.MonkeyPatch): - """ - Some tests depend on V0 internals. Since both V0 and V1 use the same - LoRAModelManager it is okay to just test V0. - """ - with monkeypatch.context() as m: - m.setenv('VLLM_USE_V1', '0') - yield - - @pytest.mark.parametrize("device", DEVICES) def test_from_lora_tensors(sql_lora_files, device): tensors = load_file( @@ -326,7 +317,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): max_loras=2, lora_dtype=DEFAULT_DTYPE), device=device) - assert all(x is None for x in manager.lora_index_to_id) # Add up to capacity @@ -430,32 +420,40 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): @pytest.mark.parametrize("device", DEVICES) -def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files, device): +def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, + tmp_path): lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE) + + dummy_lora_files = f"{tmp_path}/lora_adapter" + os.makedirs(dummy_lora_files, exist_ok=True) + create_peft_lora( + dummy_model, + save_dir=dummy_lora_files, + target_modules=["layer1.dense1", "dense2"], + lora_dtype=DEFAULT_DTYPE, + ) worker_adapter_manager = LRUCacheWorkerLoRAManager( - 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - - lora_config.lora_extra_vocab_size, lora_config, device, - EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - 
worker_adapter_manager.create_lora_manager( - llama_2_7b_model_extra_embeddings) + 4, 2, + dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size, + lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) + worker_adapter_manager.create_lora_manager(dummy_model) mapping = LoRAMapping([], []) worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("3", 3, sql_lora_files), - LoRARequest("4", 4, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("3", 3, dummy_lora_files), + LoRARequest("4", 4, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -464,9 +462,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files), - LoRARequest("5", 5, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files), + LoRARequest("5", 5, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -475,9 +473,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, 
sql_lora_files), - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -486,9 +484,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 worker_adapter_manager.set_active_adapters([ - LoRARequest("6", 6, sql_lora_files), - LoRARequest("7", 7, sql_lora_files), - LoRARequest("8", 8, sql_lora_files) + LoRARequest("6", 6, dummy_lora_files), + LoRARequest("7", 7, dummy_lora_files), + LoRARequest("8", 8, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -499,11 +497,11 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, # Over capacity with pytest.raises(RuntimeError): worker_adapter_manager.set_active_adapters([ - LoRARequest("10", 10, sql_lora_files), - LoRARequest("11", 11, sql_lora_files), - LoRARequest("12", 12, sql_lora_files), - LoRARequest("13", 13, sql_lora_files), - LoRARequest("14", 14, sql_lora_files) + LoRARequest("10", 10, dummy_lora_files), + LoRARequest("11", 11, dummy_lora_files), + LoRARequest("12", 12, dummy_lora_files), + LoRARequest("13", 13, dummy_lora_files), + LoRARequest("14", 14, dummy_lora_files) ], mapping) assert worker_adapter_manager.device == device @@ -512,33 +510,41 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, @pytest.mark.parametrize("device", DEVICES) -def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files, device): +def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, + tmp_path): # Should remove every LoRA not specified in 
the request. lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE) worker_adapter_manager = WorkerLoRAManager( - 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - + 4, 2, dummy_model_gate_up.unpadded_vocab_size - lora_config.lora_extra_vocab_size, lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - worker_adapter_manager.create_lora_manager( - llama_2_7b_model_extra_embeddings) + worker_adapter_manager.create_lora_manager(dummy_model_gate_up) + + dummy_lora_files = f"{tmp_path}/lora_adapter" + os.makedirs(dummy_lora_files, exist_ok=True) + create_peft_lora( + dummy_model_gate_up, + save_dir=dummy_lora_files, + target_modules=["layer1.dense1", "dense2"], + lora_dtype=DEFAULT_DTYPE, + ) mapping = LoRAMapping([], []) worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("3", 3, sql_lora_files), - LoRARequest("4", 4, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("3", 3, dummy_lora_files), + LoRARequest("4", 4, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 3, 4} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -546,9 +552,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files), - LoRARequest("5", 5, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), 
+ LoRARequest("2", 2, dummy_lora_files), + LoRARequest("5", 5, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -556,9 +562,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -566,9 +572,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None worker_adapter_manager.set_active_adapters([ - LoRARequest("6", 6, sql_lora_files), - LoRARequest("7", 7, sql_lora_files), - LoRARequest("8", 8, sql_lora_files) + LoRARequest("6", 6, dummy_lora_files), + LoRARequest("7", 7, dummy_lora_files), + LoRARequest("8", 8, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {6, 7, 8} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8 @@ -578,11 +584,11 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, # Over capacity with pytest.raises(RuntimeError): worker_adapter_manager.set_active_adapters([ - LoRARequest("10", 10, sql_lora_files), - LoRARequest("11", 11, sql_lora_files), - LoRARequest("12", 12, sql_lora_files), - LoRARequest("13", 13, sql_lora_files), - LoRARequest("14", 14, sql_lora_files) + LoRARequest("10", 10, dummy_lora_files), + LoRARequest("11", 11, dummy_lora_files), + LoRARequest("12", 12, dummy_lora_files), + LoRARequest("13", 13, dummy_lora_files), + LoRARequest("14", 14, dummy_lora_files) ], 
mapping) assert worker_adapter_manager.device == device diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index 0ea07793311cb..03e5d8d5d6728 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -50,7 +50,6 @@ def test_mixtral_lora(mixtral_lora_files, tp_size): max_loras=4, distributed_executor_backend="ray", tensor_parallel_size=tp_size, - enable_chunked_prefill=True, ) expected_lora_output = [ diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py deleted file mode 100644 index 3090941e63679..0000000000000 --- a/tests/lora/test_phi.py +++ /dev/null @@ -1,71 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import vllm -from vllm.lora.request import LoRARequest - -MODEL_PATH = "microsoft/phi-2" - -PROMPT_TEMPLATE = "### Instruct: {sql_prompt}\n\n### Context: {context}\n\n### Output:" # noqa: E501 - - -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: - prompts = [ - PROMPT_TEMPLATE.format( - sql_prompt= - "Which catalog publisher has published the most catalogs?", - context="CREATE TABLE catalogs (catalog_publisher VARCHAR);"), - PROMPT_TEMPLATE.format( - sql_prompt= - "Which trip started from the station with the largest dock count? Give me the trip id.", # noqa: E501 - context= - "CREATE TABLE trip (id VARCHAR, start_station_id VARCHAR); CREATE TABLE station (id VARCHAR, dock_count VARCHAR);" # noqa: E501 - ), - PROMPT_TEMPLATE.format( - sql_prompt= - "How many marine species are found in the Southern Ocean?", # noqa: E501 - context= - "CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));" # noqa: E501 - ), - ] - sampling_params = vllm.SamplingParams(temperature=0, - max_tokens=64, - stop="### End") - outputs = llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None, - ) - # Print the outputs. 
- generated_texts: list[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -def test_phi2_lora(phi2_lora_files): - # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, - # Otherwise, the lora-test will fail due to CUDA OOM. - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=2, - enforce_eager=True, - enable_chunked_prefill=True) - - expected_lora_output = [ - "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501 - "SELECT trip.id FROM trip JOIN station ON trip.start_station_id = station.id WHERE station.dock_count = (SELECT MAX(dock_count) FROM station);", # noqa: E501 - "SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';", # noqa: E501 - ] - - output1 = do_sample(llm, phi2_lora_files, lora_id=1) - for i in range(len(expected_lora_output)): - assert output1[i].startswith(expected_lora_output[i]) - output2 = do_sample(llm, phi2_lora_files, lora_id=2) - for i in range(len(expected_lora_output)): - assert output2[i].startswith(expected_lora_output[i]) diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index bd0aea67b9702..a836ff94ba3ed 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -4,17 +4,14 @@ import os import random import tempfile -from typing import Union from unittest.mock import patch -import vllm.envs as envs from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VllmConfig) from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest -from vllm.v1.worker.gpu_worker import Worker as V1Worker -from vllm.worker.worker import Worker +from vllm.v1.worker.gpu_worker import Worker 
NUM_LORAS = 16 @@ -22,18 +19,11 @@ NUM_LORAS = 16 @patch.dict(os.environ, {"RANK": "0"}) def test_worker_apply_lora(sql_lora_files): - def set_active_loras(worker: Union[Worker, V1Worker], - lora_requests: list[LoRARequest]): + def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]): lora_mapping = LoRAMapping([], []) - if isinstance(worker, Worker): - # v0 case - worker.model_runner.set_active_loras(lora_requests, lora_mapping) - else: - # v1 case - worker.model_runner.lora_manager.set_active_adapters( - lora_requests, lora_mapping) - worker_cls = V1Worker if envs.VLLM_USE_V1 else Worker + worker.model_runner.lora_manager.set_active_adapters( + lora_requests, lora_mapping) vllm_config = VllmConfig( model_config=ModelConfig( @@ -62,7 +52,7 @@ def test_worker_apply_lora(sql_lora_files): max_cpu_loras=NUM_LORAS, max_loras=NUM_LORAS), ) - worker = worker_cls( + worker = Worker( vllm_config=vllm_config, local_rank=0, rank=0, diff --git a/tests/lora/utils.py b/tests/lora/utils.py index cc1b0d81955bc..7cda90787b6f1 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import os from dataclasses import dataclass from typing import Optional, Union import torch +from safetensors.torch import save_file from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights @@ -340,3 +343,76 @@ def generate_data_for_nslices( seq_len_tensor, indices, ) + + +def create_peft_lora( + model: torch.nn.Module, + save_dir: str, + target_modules: list[str], + rank: int = 8, + alpha: int = 16, + dropout: float = 0.1, + lora_dtype: torch.dtype = torch.float16, +) -> dict[str, torch.Tensor]: + lora_weights = {} + adapter_config = { + "peft_type": "LORA", + "auto_mapping": None, + "base_model_name_or_path": "dummy_model", + "revision": None, + "task_type": "CAUSAL_LM", + "inference_mode": False, + "r": rank, + "lora_alpha": alpha, + 
"lora_dropout": dropout, + "fan_in_fan_out": False, + "bias": "none", + "modules_to_save": None, + "init_lora_weights": True, + "layers_to_transform": None, + "layers_pattern": None, + "target_modules": target_modules, + "exclude_modules": None, + "use_rslora": False, + "use_dora": False, + "loftq_config": None, + } + + for module_name in target_modules: + + module = model + for attr in module_name.split("."): + module = getattr(module, attr) + + if hasattr(module, "input_size") and hasattr(module, "output_size"): + + in_features = module.input_size + out_features = module.output_size + + elif hasattr(module, "embedding_dim") and hasattr( + module, "num_embeddings"): + # ParallelLMHead + in_features = module.embedding_dim + out_features = module.num_embeddings + else: + raise ValueError( + f"Unable to determine dimensions for module {module_name}") + + lora_A = torch.randn(rank, in_features, dtype=lora_dtype) + + torch.nn.init.kaiming_uniform_(lora_A, a=5**0.5) + + lora_B = torch.zeros(out_features, rank, dtype=lora_dtype) + + # PEFT style + lora_weights[f"base_model.model.{module_name}.lora_A.weight"] = lora_A + lora_weights[f"base_model.model.{module_name}.lora_B.weight"] = lora_B + + config_path = os.path.join(save_dir, "adapter_config.json") + with open(config_path, "w", encoding="utf-8") as f: + json.dump(adapter_config, f, indent=2, ensure_ascii=False) + + weights_path = os.path.join(save_dir, "adapter_model.safetensors") + save_file(lora_weights, weights_path) + + return lora_weights diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index aee0a50336c09..7e7cc893ec8aa 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -31,6 +31,7 @@ HYBRID_MODELS = [ "hmellor/tiny-random-BambaForCausalLM", "ibm-granite/granite-4.0-tiny-preview", "tiiuae/Falcon-H1-0.5B-Base", + "LiquidAI/LFM2-1.2B", ] HF_UNSUPPORTED_MODELS = [ @@ -52,18 +53,21 @@ 
V1_SUPPORTED_MODELS = [ "hmellor/tiny-random-BambaForCausalLM", "ibm-granite/granite-4.0-tiny-preview", "tiiuae/Falcon-H1-0.5B-Base", + "LiquidAI/LFM2-1.2B", +] + +FULL_CUDA_GRAPH_MODELS = [ + "ai21labs/Jamba-tiny-dev", + "Zyphra/Zamba2-1.2B-instruct", +] + +V0_UNSUPPORTED_MODELS = [ + "LiquidAI/LFM2-1.2B", ] # Avoid OOM MAX_NUM_SEQS = 4 -# Once we add support for FCG in Mamba1, this list will be removed and tests -# all test cases will use enforce_eager=False -ENFORCE_EAGER_MODELS_V1 = [ - "state-spaces/mamba-130m-hf", - "ai21labs/Jamba-tiny-dev", -] - @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @@ -96,31 +100,25 @@ def test_models( else: hf_outputs = None - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_v0_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + if model not in V0_UNSUPPORTED_MODELS: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_v0_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + else: + vllm_v0_outputs = None if model in V1_SUPPORTED_MODELS: - enforce_eager = False with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - if model in HYBRID_MODELS: - # required due to reorder_batch behaviour - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") - - if model in ENFORCE_EAGER_MODELS_V1: - enforce_eager = True - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, - enforce_eager=enforce_eager, enable_prefix_caching=False) as vllm_model: vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) else: vllm_v1_outputs = None - if hf_outputs is not None: + if hf_outputs is not None and vllm_v0_outputs is not None: check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_v0_outputs, @@ -130,6 +128,7 @@ def test_models( if model in V1_SUPPORTED_MODELS: ref_outputs = hf_outputs if hf_outputs is not None else 
vllm_v0_outputs + assert ref_outputs is not None check_logprobs_close( outputs_0_lst=ref_outputs, outputs_1_lst=vllm_v1_outputs, @@ -148,6 +147,9 @@ def test_batching( max_tokens: int, num_logprobs: int, ) -> None: + if model in V0_UNSUPPORTED_MODELS: + pytest.skip( + f"Unsupported V0 Engine. Skipping `test_batching` on {model}.") try: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) @@ -373,7 +375,7 @@ def test_distributed_correctness( ) -@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"]) +@pytest.mark.parametrize("model", FULL_CUDA_GRAPH_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_full_cuda_graph( @@ -400,9 +402,12 @@ def test_full_cuda_graph( else: hf_outputs = None - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_v0_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + if model not in V0_UNSUPPORTED_MODELS: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_v0_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + else: + vllm_v0_outputs = None with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -416,7 +421,7 @@ def test_full_cuda_graph( vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - if hf_outputs is not None: + if hf_outputs is not None and vllm_v0_outputs is not None: check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_v0_outputs, @@ -425,6 +430,7 @@ def test_full_cuda_graph( ) ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs + assert ref_outputs is not None check_logprobs_close( outputs_0_lst=ref_outputs, outputs_1_lst=vllm_v1_outputs, diff --git a/tests/models/language/generation/test_mbart.py b/tests/models/language/generation/test_mbart.py new file mode 100644 index 0000000000000..854a72713943b --- /dev/null +++ 
b/tests/models/language/generation/test_mbart.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import pytest +from transformers import AutoModelForSeq2SeqLM + +from vllm.sequence import SampleLogprobs + +from ....conftest import DecoderPromptType, HfRunner, VllmRunner +from ...utils import check_logprobs_close + + +def vllm_to_hf_output( + vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], + decoder_prompt_type: DecoderPromptType, +): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + hf_output_str = output_str + "</s>" + return output_ids, hf_output_str, out_logprobs + + +def run_test( + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + prompts: list[dict[str, str]], + decoder_prompt_type: DecoderPromptType, + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +) -> None: + ''' + Test the vLLM mBART model by validating it against HuggingFace (HF). 
+ (Docstring content is omitted for brevity) + ''' + + vllm_prompts = prompts + if decoder_prompt_type == DecoderPromptType.NONE: + vllm_prompts = [{ + "encoder_prompt": p['encoder_prompt'], + "decoder_prompt": "" + } for p in prompts] + + vllm_kwargs = { + "hf_overrides": { + "architectures": ["MBartForConditionalGeneration"] + } + } + + with vllm_runner(model, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + **vllm_kwargs) as vllm_model: # type: ignore + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + vllm_prompts, max_tokens, num_logprobs) + + hf_kwargs = { + "top_k": None, + "num_beams": 1, + "repetition_penalty": 1.0, + "top_p": 1.0, + "length_penalty": 1.0, + "early_stopping": False, + "no_repeat_ngram_size": None, + "min_length": 0 + } + + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSeq2SeqLM) as hf_model: + hf_kwargs["decoder_start_token_id"] = ( + hf_model.tokenizer.lang_code_to_id["ro_RO"]) + + hf_outputs = ( + hf_model.generate_encoder_decoder_greedy_logprobs_limit( + prompts, # HF runner still uses the original prompts + max_tokens, + num_logprobs, + **hf_kwargs, + )) + + hf_skip_tokens = 0 + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, decoder_prompt_type) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + num_outputs_0_skip_tokens=hf_skip_tokens, + ) + + +@pytest.mark.parametrize( + "model", + [pytest.param("facebook/mbart-large-en-ro")], +) +@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) +def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, + dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: + + run_test( + hf_runner, + 
vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/tests/models/language/pooling/test_gritlm.py b/tests/models/language/pooling/test_gritlm.py index d21987571cbaa..17a55d916b1ff 100644 --- a/tests/models/language/pooling/test_gritlm.py +++ b/tests/models/language/pooling/test_gritlm.py @@ -14,6 +14,7 @@ from ....utils import RemoteOpenAIServer MODEL_NAME = "parasail-ai/GritLM-7B-vllm" MAX_MODEL_LEN = 4000 +ATOL = 0.002 def _arr(arr): @@ -97,16 +98,16 @@ def get_test_data(): def validate_embed_output(q_rep: list[list[float]], d_rep: list[list[float]]): cosine_sim_q0_d0 = 1 - cosine(q_rep[0], d_rep[0]) - assert cosine_sim_q0_d0 == pytest.approx(0.609, abs=0.001) + assert cosine_sim_q0_d0 == pytest.approx(0.609, abs=ATOL) cosine_sim_q0_d1 = 1 - cosine(q_rep[0], d_rep[1]) - assert cosine_sim_q0_d1 == pytest.approx(0.101, abs=0.001) + assert cosine_sim_q0_d1 == pytest.approx(0.101, abs=ATOL) cosine_sim_q1_d0 = 1 - cosine(q_rep[1], d_rep[0]) - assert cosine_sim_q1_d0 == pytest.approx(0.120, abs=0.001) + assert cosine_sim_q1_d0 == pytest.approx(0.120, abs=ATOL) cosine_sim_q1_d1 = 1 - cosine(q_rep[1], d_rep[1]) - assert cosine_sim_q1_d1 == pytest.approx(0.534, abs=0.001) + assert cosine_sim_q1_d1 == pytest.approx(0.534, abs=ATOL) def test_gritlm_offline_embedding(vllm_runner): diff --git a/tests/models/language/pooling/test_multilabel_classification_support.py b/tests/models/language/pooling/test_multilabel_classification_support.py new file mode 100644 index 0000000000000..45366f2094144 --- /dev/null +++ b/tests/models/language/pooling/test_multilabel_classification_support.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from transformers import AutoModelForSequenceClassification + + 
+@pytest.mark.parametrize( + "model", + ["Rami/multi-label-class-classification-on-github-issues"], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_classify_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.classify(example_prompts) + + with hf_runner(model, + dtype=dtype, + auto_cls=AutoModelForSequenceClassification) as hf_model: + hf_outputs = hf_model.classify(example_prompts) + + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output) + vllm_output = torch.tensor(vllm_output) + + assert torch.allclose(hf_output, vllm_output, + 1e-3 if dtype == "float" else 1e-2) diff --git a/tests/models/language/pooling/test_st_projector.py b/tests/models/language/pooling/test_st_projector.py new file mode 100644 index 0000000000000..51ddbcc5ab249 --- /dev/null +++ b/tests/models/language/pooling/test_st_projector.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo +from .mteb_utils import mteb_test_embed_models + +# ST models with projector (Dense) layers +ST_PROJECTOR_MODELS = [ + CLSPoolingEmbedModelInfo( + "TencentBAC/Conan-embedding-v1", + architecture="BertModel", + enable_test=True, + ), +] + + +@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS) +def test_embed_models_mteb(hf_runner, vllm_runner, + model_info: EmbedModelInfo) -> None: + + mteb_test_embed_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 2919bdbe91bbd..96208f8eda628 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -621,6 +621,23 @@ VLM_TEST_SETTINGS = { 
hf_model_kwargs={"llm_attn_implementation": "sdpa"}, patch_hf_runner=model_utils.ovis_patch_hf_runner, ), + "ovis2_5": VLMTestInfo( + models=["AIDC-AI/Ovis2.5-2B"], + test_type=( + VLMTestType.IMAGE, + VLMTestType.MULTI_IMAGE, + VLMTestType.VIDEO + ), + prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501 + video_idx_to_prompt=lambda idx: "<video>\n", + max_model_len=4096, + max_num_seqs=2, + dtype="half", + num_logprobs=10, + patch_hf_runner=model_utils.ovis2_5_patch_hf_runner, + hf_model_kwargs={"revision": "refs/pr/5"}, + ), "phi3v": VLMTestInfo( models=["microsoft/Phi-3.5-vision-instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), diff --git a/tests/models/multimodal/generation/test_mllama.py b/tests/models/multimodal/generation/test_mllama.py index b413c4d6b3667..1c32cc6d71c04 100644 --- a/tests/models/multimodal/generation/test_mllama.py +++ b/tests/models/multimodal/generation/test_mllama.py @@ -5,6 +5,7 @@ from typing import Optional, overload import pytest import torch +from packaging.version import Version from transformers import AutoConfig, AutoModelForImageTextToText, AutoTokenizer from transformers import __version__ as TRANSFORMERS_VERSION @@ -287,8 +288,8 @@ def clear_cache(): @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.skipif( - TRANSFORMERS_VERSION == "4.55.0", - reason="Transformers v4.55.0 has a regression issue on mllama, " + Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), + reason="Transformers v4.55 has a regression issue on mllama, " "see: https://github.com/huggingface/transformers/pull/40083") def test_models_single_leading_image(hf_runner, vllm_runner, image_assets, model, sizes, dtype, max_tokens, @@ -319,8 +320,8 @@ def 
test_models_single_leading_image(hf_runner, vllm_runner, image_assets, @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.skipif( - TRANSFORMERS_VERSION == "4.55.0", - reason="Transformers v4.55.0 has a regression issue on mllama, " + Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), + reason="Transformers v4.55 has a regression issue on mllama, " "see: https://github.com/huggingface/transformers/pull/40083") def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, model, dtype, max_tokens, num_logprobs, @@ -372,8 +373,8 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.skipif( - TRANSFORMERS_VERSION == "4.55.0", - reason="Transformers v4.55.0 has a regression issue on mllama, " + Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), + reason="Transformers v4.55 has a regression issue on mllama, " "see: https://github.com/huggingface/transformers/pull/40083") def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, dtype, max_tokens, num_logprobs, @@ -416,8 +417,8 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.skipif( - TRANSFORMERS_VERSION == "4.55.0", - reason="Transformers v4.55.0 has a regression issue on mllama, " + Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), + reason="Transformers v4.55 has a regression issue on mllama, " "see: https://github.com/huggingface/transformers/pull/40083") def test_models_distributed( hf_runner, diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index 5e8dac6bce96a..8b7d051218f14 100644 --- 
a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -10,6 +10,7 @@ from typing import Optional, Union import numpy as np import numpy.typing as npt +import PIL.Image import pytest import regex as re import torch @@ -19,7 +20,6 @@ from transformers import (AutoConfig, AutoTokenizer, BatchFeature, from transformers.video_utils import VideoMetadata from vllm.sequence import SampleLogprobs -from vllm.transformers_utils.tokenizer import patch_padding_side from vllm.utils import is_list_of from .....conftest import HfRunner, ImageAsset, ImageTestAssets @@ -343,7 +343,6 @@ def gemma3_patch_hf_runner(hf_model: HfRunner) -> HfRunner: def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for GLM4V.""" hf_processor = hf_model.processor - patch_padding_side(hf_processor) def processor(*args, text="", images=None, **kwargs): if images is None: @@ -812,6 +811,63 @@ def ovis_patch_hf_runner(hf_model: HfRunner) -> HfRunner: return hf_model +def ovis2_5_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + """Patches and returns an instance of the HfRunner to use for Ovis2.""" + hf_model.model.get_output_embeddings = lambda: \ + hf_model.model.llm.get_output_embeddings() + + def processor(*args, text="", images=None, videos=None, **kwargs): + if images is None: + images = [] + else: + images = [images] if isinstance(images, Image) else images + if videos is None: + videos = [] + else: + videos = [videos] if isinstance(videos, np.ndarray) else videos + videos = [[PIL.Image.fromarray(frame) for frame in vid] + for vid in videos] + + prompt_start_and_end = { + "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"), + "llama": + ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"), + "gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"), + } + for start, end in prompt_start_and_end.values(): + if start in text and end in text: + text = 
text.split(start)[1].split(end)[0] + break + + images_message = [{"type": "image", "image": img} for img in images] + videos_message = [{"type": "video", "video": vid} for vid in videos] + + messages = [{ + "role": + "user", + "content": [ + *images_message, + *videos_message, + { + "type": "text", + "text": text + }, + ], + }] + + input_ids, pixel_values, grid_thws = hf_model.model.preprocess_inputs( + messages=messages, enable_thinking=True) + inputs = { + "inputs": input_ids, + "pixel_values": pixel_values, + "grid_thws": grid_thws, + } + return BatchFeature(data=inputs, tensor_type="pt") + + hf_model.processor = processor + return hf_model + + def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner for Qwen2.5-Omni.""" thinker = hf_model.model.thinker diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 906966ddd0649..3ff4360b83345 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -14,8 +14,9 @@ from PIL import Image from vllm.config import ModelConfig from vllm.inputs import InputProcessingContext from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict +from vllm.multimodal.cache import MultiModalProcessorOnlyCache from vllm.multimodal.inputs import MultiModalInputs -from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache +from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, cached_tokenizer_from_config, encode_tokens) @@ -63,6 +64,8 @@ def _test_processing_correctness( revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, + # Ensure that the cache can fit all of the data + mm_processor_cache_gb=2048, ) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) @@ -71,8 +74,7 
@@ def _test_processing_correctness( model_config, tokenizer=cached_tokenizer_from_config(model_config), ) - # Ensure that it can fit all of the data - cache = ProcessingCache(capacity_gb=2048) + cache = MultiModalProcessorOnlyCache(model_config) processing_info = factories.info(ctx) supported_mm_limits = processing_info.get_supported_mm_limits() @@ -102,7 +104,7 @@ def _test_processing_correctness( partial(random_video, rng, min_frames=2, - max_frames=8, + max_frames=16, min_wh=128, max_wh=256), "audio": @@ -160,8 +162,10 @@ def _test_processing_correctness( # incorrect token ids. So we need use `add_special_tokens=False` here # to leave bos_token to be added by the processor. _ADD_SPECIAL_TOKENS_OVERRIDES = { + "donut": False, "mllama": False, "ovis": False, + "ovis2_5": False, "paligemma": False, "ultravox": False, "whisper": False, @@ -267,23 +271,30 @@ def _test_processing_correctness_one( "CohereForAI/aya-vision-8b", "Salesforce/blip2-opt-2.7b", "facebook/chameleon-7b", + "CohereLabs/command-a-vision-07-2025", "deepseek-ai/deepseek-vl2-tiny", + "naver-clova-ix/donut-base-finetuned-docvqa", + "baidu/ERNIE-4.5-VL-28B-A3B-PT", "microsoft/Florence-2-base", "adept/fuyu-8b", "google/gemma-3-4b-it", "google/gemma-3n-E2B-it", "zai-org/glm-4v-9b", "zai-org/GLM-4.1V-9B-Thinking", + "zai-org/GLM-4.5V", "ibm-granite/granite-speech-3.3-2b", "h2oai/h2ovl-mississippi-800m", + "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", + "HuggingFaceM4/Idefics3-8B-Llama3", "internlm/Intern-S1", "OpenGVLab/InternVL2-1B", "OpenGVLab/InternVL3-1B", - "HuggingFaceM4/Idefics3-8B-Llama3", - "HuggingFaceTB/SmolVLM2-2.2B-Instruct", + "OpenGVLab/InternVL3_5-1B", + "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview", + "OpenGVLab/InternVL3_5-30B-A3B", + "Kwai-Keye/Keye-VL-8B-Preview", "moonshotai/Kimi-VL-A3B-Instruct", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", "llava-hf/llava-1.5-7b-hf", "llava-hf/llava-v1.6-mistral-7b-hf", 
"llava-hf/LLaVA-NeXT-Video-7B-hf", @@ -301,6 +312,7 @@ def _test_processing_correctness_one( "AIDC-AI/Ovis1.6-Gemma2-9B", "AIDC-AI/Ovis1.6-Llama3.2-3B", "AIDC-AI/Ovis2-1B", + "AIDC-AI/Ovis2.5-2B", "google/paligemma-3b-mix-224", "google/paligemma2-3b-ft-docci-448", "microsoft/Phi-3.5-vision-instruct", @@ -312,11 +324,15 @@ def _test_processing_correctness_one( "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", "Qwen/Qwen2.5-Omni-3B", + "YannQi/R-4B", "Skywork/Skywork-R1V-38B", + "HuggingFaceTB/SmolVLM2-2.2B-Instruct", + "stepfun-ai/step3", "fixie-ai/ultravox-v0_5-llama-3_2-1b", "openai/whisper-large-v3", "omni-research/Tarsier-7b", "omni-research/Tarsier2-Recap-7b", + "mistralai/Voxtral-Mini-3B-2507", ]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) @@ -370,10 +386,16 @@ def _assert_inputs_equal( if ignore_mm_keys is None: ignore_mm_keys = set() - assert "mm_kwargs" in a and "mm_kwargs" in b, msg + a_rest = {k: v for k, v in a.items() if k != "mm_kwargs"} + b_rest = {k: v for k, v in b.items() if k != "mm_kwargs"} + + assert a_rest == b_rest, msg + + a_data = a["mm_kwargs"].get_data() + b_data = b["mm_kwargs"].get_data() for key in ignore_mm_keys: - a["mm_kwargs"].pop(key, None) - b["mm_kwargs"].pop(key, None) + a_data.pop(key, None) + b_data.pop(key, None) - assert a == b, msg + assert a_data == b_data, msg diff --git a/tests/models/multimodal/processing/test_glm4_1v.py b/tests/models/multimodal/processing/test_glm4_1v.py index a6d900ec5d895..a49842e1099c2 100644 --- a/tests/models/multimodal/processing/test_glm4_1v.py +++ b/tests/models/multimodal/processing/test_glm4_1v.py @@ -45,7 +45,8 @@ def test_processor_override( video_token_id = tokenizer.convert_tokens_to_ids(hf_processor.video_token) video_tok_count = processed_inputs["prompt_token_ids"].count( video_token_id) - grid_t, _, _ = processed_inputs["mm_kwargs"]["video_grid_thw"][0] + grid_t, _, _ = processed_inputs["mm_kwargs"].get_data( + 
)["video_grid_thw"][0] assert grid_t == expected_grid_t assert video_tok_count == expected_toks_per_frame * grid_t diff --git a/tests/models/multimodal/processing/test_h2ovl.py b/tests/models/multimodal/processing/test_h2ovl.py index 76e4acc67d4d5..1adfe21352c41 100644 --- a/tests/models/multimodal/processing/test_h2ovl.py +++ b/tests/models/multimodal/processing/test_h2ovl.py @@ -108,7 +108,8 @@ def _run_check( # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data( + )["pixel_values_flat"].shape assert img_tok_count == 256 * total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches diff --git a/tests/models/multimodal/processing/test_internvl.py b/tests/models/multimodal/processing/test_internvl.py index c3e2841a8f060..e4f25f5ac7123 100644 --- a/tests/models/multimodal/processing/test_internvl.py +++ b/tests/models/multimodal/processing/test_internvl.py @@ -68,7 +68,8 @@ def _run_check( # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data( + )["pixel_values_flat"].shape assert img_tok_count == 256 * total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches diff --git a/tests/models/multimodal/processing/test_llama4.py b/tests/models/multimodal/processing/test_llama4.py index 5e14f0f9964d6..bea4f43567eee 100644 --- a/tests/models/multimodal/processing/test_llama4.py +++ b/tests/models/multimodal/processing/test_llama4.py @@ -51,14 +51,14 @@ def 
test_processor_override( prompt = encode_tokens(tokenizer, prompt) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) - mm_kwargs = processed_inputs["mm_kwargs"] + mm_data = processed_inputs["mm_kwargs"].get_data() # place holder replacements prompt_token_ids = processed_inputs["prompt_token_ids"] assert prompt_token_ids.count(config.boi_token_index) == num_imgs assert prompt_token_ids.count(config.eoi_token_index) == num_imgs assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs - aspect_ratios = mm_kwargs["aspect_ratios"] + aspect_ratios = mm_data["aspect_ratios"] num_x_separators = num_y_separators = 0 for tiles_y, tiles_x in aspect_ratios: if tiles_x * tiles_y > 1: @@ -80,6 +80,6 @@ def test_processor_override( num_patches_per_chunk = processor.info.get_patch_per_chunk( config.vision_config) assert prompt_token_ids.count(config.image_token_index) \ - == mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk - assert mm_kwargs["pixel_values"].shape[0] \ - == mm_kwargs["patches_per_image"].sum() + == sum(mm_data["patches_per_image"]) * num_patches_per_chunk + assert len(mm_data["pixel_values"]) \ + == sum(mm_data["patches_per_image"]) diff --git a/tests/models/multimodal/processing/test_mllama.py b/tests/models/multimodal/processing/test_mllama.py index a6b20a1e3678e..b42d3f89f3cbf 100644 --- a/tests/models/multimodal/processing/test_mllama.py +++ b/tests/models/multimodal/processing/test_mllama.py @@ -49,18 +49,18 @@ def test_profiling( encoder_seq_lens = [len(dummy_encoder_data.prompt_token_ids) ] * max_num_seqs - mm_kwargs = processor.apply( + mm_data = processor.apply( prompt=dummy_mm_data.prompt, mm_data=dummy_mm_data.mm_data, hf_processor_mm_kwargs=dict(), - )["mm_kwargs"] + )["mm_kwargs"].get_data() # Get the actual number of encoder tokens for each sample. 
# Because attn_metadata.encoder_seq_lens only counts the last # group of images for each sample, which is used to cheat the # block manager to allocate blocks for those images only. # See MllamaMultiModalProcessor for more details. - num_tiles = [[t] for t in mm_kwargs.pop("num_tiles")] + num_tiles = [[t] for t in mm_data.pop("num_tiles")] num_tokens_per_tile = calc_token_per_chunk(image_size) actual_encoder_seq_lens = [ sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles diff --git a/tests/models/multimodal/processing/test_mllama4.py b/tests/models/multimodal/processing/test_mllama4.py index f3871b60c3f64..3be77b5da63f2 100644 --- a/tests/models/multimodal/processing/test_mllama4.py +++ b/tests/models/multimodal/processing/test_mllama4.py @@ -38,21 +38,21 @@ def test_profiling(model_id: str, max_model_len: int): hf_config = ctx.get_hf_config(Llama4Config) - mm_kwargs = processor.apply( + mm_data = processor.apply( prompt=dummy_mm_data.prompt, mm_data=dummy_mm_data.mm_data, hf_processor_mm_kwargs=dict(), - )["mm_kwargs"] + )["mm_kwargs"].get_data() image_size = hf_config.vision_config.image_size patch_size = hf_config.vision_config.patch_size downsample_ratio = int( round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2))) tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio - chunks_per_image = prod(mm_kwargs["patches_per_image"]) + chunks_per_image = prod(mm_data["patches_per_image"]) total_num_patches = chunks_per_image * tokens_per_patch - num_tiles = mm_kwargs["aspect_ratios"][0][0] * mm_kwargs["aspect_ratios"][ - 0][1] # x-y seperator tokens + num_tiles = mm_data["aspect_ratios"][0][0] * mm_data["aspect_ratios"][0][ + 1] # x-y seperator tokens total_tokens = total_num_patches.item() + num_tiles.item( ) + 3 # image start, image, image end diff --git a/tests/models/multimodal/processing/test_nemotron_vl.py b/tests/models/multimodal/processing/test_nemotron_vl.py index 6fbbab0d26124..d9f1965a053df 100644 --- 
a/tests/models/multimodal/processing/test_nemotron_vl.py +++ b/tests/models/multimodal/processing/test_nemotron_vl.py @@ -70,7 +70,8 @@ def _run_check( # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("<image>") img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data( + )["pixel_values_flat"].shape print("Image token count:", img_tok_count, "Pixel shape:", pixel_shape) assert img_tok_count == 256 * total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches diff --git a/tests/models/multimodal/processing/test_qwen2_vl.py b/tests/models/multimodal/processing/test_qwen2_vl.py index 9d1cd183387bc..985f4188fdb66 100644 --- a/tests/models/multimodal/processing/test_qwen2_vl.py +++ b/tests/models/multimodal/processing/test_qwen2_vl.py @@ -48,7 +48,8 @@ def test_processor_override( hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token) img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data( + )["pixel_values"].shape assert img_tok_count == expected_toks_per_img * num_imgs assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py new file mode 100644 index 0000000000000..2d8cd49edc73b --- /dev/null +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -0,0 +1,264 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from functools import partial +from typing import Any, 
Union +from unittest.mock import patch + +import numpy as np +import pytest +from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk, + UserMessage) +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from PIL import Image + +from vllm.config import ModelConfig +from vllm.engine.llm_engine import LLMEngine as V0LLMEngine +from vllm.inputs import InputProcessingContext +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, + MultiModalKwargs) +from vllm.multimodal.processing import BaseMultiModalProcessor +from vllm.multimodal.utils import group_mm_kwargs_by_modality +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.utils import GiB_bytes, is_list_of, set_default_torch_num_threads +from vllm.v1.core.kv_cache_utils import get_kv_cache_config +from vllm.v1.engine.core import EngineCore as V1EngineCore + +from ....conftest import VllmRunner +from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS +from ...utils import dummy_hf_overrides + +ARCH_TO_SKIP = { + "MolmoForCausalLM": "incompatible requirements", +} +ARCH_NEEDS_EXTRAS = [ + "InternVLChatModel", + "Idefics3ForConditionalGeneration", + "LlavaForConditionalGeneration", + "MiniCPMV", + "PaliGemmaForConditionalGeneration", +] +REPO_ID_TO_SKIP = { + "nm-testing/pixtral-12b-FP8-dynamic": "duplicated test", + # FIXME(Isotr0py): enable GPT-OSS based InternVL3.5 model + # after support PP for GPT-OSS + "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview": "Broken model", +} + +ImageInput = list[Image.Image] +VideoInput = Union[list[Image.Image], list[np.ndarray], + list[tuple[np.ndarray, dict[str, Any]]]] +AudioInput = list[tuple[np.ndarray, int]] + + +def _resize_data(_data: Union[Image.Image, np.ndarray], + size_factor: float) -> Union[Image.Image, np.ndarray]: + assert size_factor <= 1, "Size factor must be less than 1" + # Image input + if isinstance(_data, Image.Image): + W, H = _data.width, _data.height + 
W, H = map(lambda x: int(x * size_factor), (W, H)) + return _data.resize((W, H)) + # Video input with PIL Images + elif is_list_of(_data, Image.Image): + W, H = next(iter(_data)).width, next(iter(_data)).height + T = len(_data) + T, W, H = map(lambda x: max(int(x * size_factor), 1), (T, W, H)) + return [d.resize((W, H)) for d in _data[:T]] + # Video input with numpy arrays + elif isinstance(_data, np.ndarray) and _data.ndim >= 4: + T, H, W, C = _data.shape[-4:] + T, H, W = map(lambda x: max(int(x * size_factor), 1), (T, H, W)) + return _data[..., :T, :H, :W, :C] + # Audio input + elif isinstance(_data, np.ndarray) and _data.ndim == 1: + return _data[:int(len(_data) * size_factor)] + raise AssertionError("This line should be unreachable.") + + +def resize_mm_data( + data: Union[ImageInput, VideoInput, AudioInput], + size_factors: tuple[float, + ...]) -> Union[ImageInput, VideoInput, AudioInput]: + size_factors = size_factors[:len(data)] + if is_list_of(data, (Image.Image, np.ndarray, list)): + return [_resize_data(d, s) for d, s in zip(data, size_factors)] + elif is_list_of(data, tuple): + return [(_resize_data(d, s), meta) + for (d, meta), s in zip(data, size_factors)] + raise ValueError("Unsupported multimodal data type.") + + +def create_batched_mm_kwargs( + model_config: ModelConfig, + processor: BaseMultiModalProcessor, + size_factors: tuple[float, ...] 
= (1.0, 0.5, 0.25), +) -> Iterable[tuple[str, int, BatchedTensorInputs]]: + processing_info = processor.info + dummy_inputs = processor.dummy_inputs + supported_mm_limits = processing_info.get_supported_mm_limits() + mm_counts = { + modality: 3 if limit is None else limit + for modality, limit in supported_mm_limits.items() + } + processor_inputs = dummy_inputs.get_dummy_processor_inputs( + seq_len=model_config.max_model_len, + mm_counts=mm_counts, + ) + mm_data = processor_inputs.mm_data + resized_mm_data = { + modality: resize_mm_data(data, size_factors) + for modality, data in mm_data.items() + } + # Mistral chat outputs tokens directly, rather than text prompts + if model_config.tokenizer_mode == "mistral": + images = resized_mm_data.get("image", []) + request = ChatCompletionRequest(messages=[ + UserMessage(content=[ + TextChunk(text=""), + *(ImageChunk(image=image) for image in images), + ]), + ]) + tokenizer = processing_info.get_tokenizer() + res = tokenizer.mistral.encode_chat_completion(request) + prompt = res.tokens + else: + prompt = processor_inputs.prompt + mm_kwargs = processor.apply( + prompt=prompt, + mm_data=resized_mm_data, + hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, + tokenization_kwargs=processor_inputs.tokenization_kwargs, + )["mm_kwargs"] + items = [ + item for modality in supported_mm_limits + for item in mm_kwargs[modality] + ] + return group_mm_kwargs_by_modality(items) + + +def get_model_id_to_test( + model_arch_list: Iterable[str]) -> list[tuple[str, str]]: + filtered_results = [] + for model_arch in model_arch_list: + model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) + if model_info.extras and model_arch in ARCH_NEEDS_EXTRAS: + available_repos = list( + map(lambda model_id: (model_arch, model_id), + [model_info.default, *model_info.extras.values()])) + filtered_results.extend(available_repos) + else: + filtered_results.append((model_arch, model_info.default)) + return filtered_results + + 
+@pytest.mark.parametrize( + "model_arch, model_id", + get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys())) +def test_model_tensor_schema(model_arch: str, model_id: str, + vllm_runner: type[VllmRunner], monkeypatch): + if model_arch in ARCH_TO_SKIP: + pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}") + if model_id in REPO_ID_TO_SKIP: + pytest.skip(f"Skipping {model_id} due to {REPO_ID_TO_SKIP[model_id]}") + + model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip", + check_max_version=False) + + hf_overrides_fn = partial(dummy_hf_overrides, + model_arch=model_arch, + exist_overrides=model_info.hf_overrides) + + model_config = ModelConfig( + model_id, + tokenizer=model_info.tokenizer or model_id, + tokenizer_mode=model_info.tokenizer_mode, + revision=model_info.revision, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + ) + model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) + factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] + + if not any( + hasattr(model_cls, f"_parse_and_validate_{m}_input") + for m in ["image", "video", "audio"]): + pytest.skip(f"{model_arch} does not support tensor schema validation.") + + ctx = InputProcessingContext( + model_config, + tokenizer=cached_tokenizer_from_config(model_config), + ) + processing_info = factories.info(ctx) + supported_mm_limits = processing_info.get_supported_mm_limits() + limit_mm_per_prompt = { + modality: 3 if limit is None else limit + for modality, limit in supported_mm_limits.items() + } + + # Avoid calling model.forward() + def _initialize_kv_caches_v0(self) -> None: + self.cache_config.num_gpu_blocks = 0 + self.cache_config.num_cpu_blocks = 0 + + def _initialize_kv_caches_v1(self, vllm_config): + kv_cache_specs = self.model_executor.get_kv_cache_specs() + scheduler_kv_cache_config = get_kv_cache_config( + 
vllm_config, + kv_cache_specs[0], + 10 * GiB_bytes, + ) + + # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config + return 1, 0, scheduler_kv_cache_config + + with (patch.object(V0LLMEngine, "_initialize_kv_caches", + _initialize_kv_caches_v0), + patch.object(V1EngineCore, "_initialize_kv_caches", + _initialize_kv_caches_v1), monkeypatch.context() as m): + m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + if model_info.v0_only: + m.setenv("VLLM_USE_V1", "0") + + # TODO(Isotr0py): Can we avoid initializing engine? + with ( + set_default_torch_num_threads(1), + vllm_runner( + model_id, + tokenizer_name=model_info.tokenizer, + tokenizer_mode=model_info.tokenizer_mode, + revision=model_info.revision, + trust_remote_code=model_info.trust_remote_code, + max_model_len=model_info.max_model_len, + load_format="dummy", + hf_overrides=hf_overrides_fn, + limit_mm_per_prompt=limit_mm_per_prompt, + enforce_eager=True, + ) as vllm_model, + ): + model_config = vllm_model.llm.llm_engine.model_config + llm_engine = vllm_model.llm.llm_engine + + if hasattr(llm_engine, "processor"): + # v1 processor + mm_registry = llm_engine.processor.mm_registry + else: + # v0 input_preprocessor + mm_registry = llm_engine.input_preprocessor.mm_registry + + processor = mm_registry.create_processor(model_config) + + def validate_model_input(model, modality: str, + mm_kwargs: MultiModalKwargs): + method_name = f"_parse_and_validate_{modality}_input" + if hasattr(model, method_name): + getattr(model, method_name)(**mm_kwargs) + + for modality, _, mm_kwargs in create_batched_mm_kwargs( + model_config, processor): + valid_func = partial(validate_model_input, + modality=modality, + mm_kwargs=mm_kwargs) + vllm_model.apply_model(valid_func) diff --git a/tests/models/multimodal/test_tensor_schema.py b/tests/models/multimodal/test_tensor_schema.py deleted file mode 100644 index 92390d8c2f7ee..0000000000000 --- a/tests/models/multimodal/test_tensor_schema.py +++ /dev/null @@ -1,156 +0,0 @@ -# 
SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from functools import partial -from unittest.mock import patch - -import pytest - -from vllm.config import ModelConfig -from vllm.engine.llm_engine import LLMEngine as V0LLMEngine -from vllm.inputs import InputProcessingContext -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.processing import BaseMultiModalProcessor -from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config -from vllm.utils import GiB_bytes, set_default_torch_num_threads -from vllm.v1.core.kv_cache_utils import get_kv_cache_config -from vllm.v1.engine.core import EngineCore as V1EngineCore - -from ...conftest import VllmRunner -from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS -from ..utils import dummy_hf_overrides - -ARCH_TO_SKIP = { - "MolmoForCausalLM": "incompatible requirements", - "MiniMaxVL01ForConditionalGeneration": "broken model", -} - - -def create_batched_mm_kwargs( - model_config: ModelConfig, - processor: BaseMultiModalProcessor, -) -> MultiModalKwargs: - processing_info = processor.info - dummy_inputs = processor.dummy_inputs - supported_mm_limits = processing_info.get_supported_mm_limits() - mm_counts = { - modality: 3 if limit is None else limit - for modality, limit in supported_mm_limits.items() - } - processor_inputs = dummy_inputs.get_dummy_processor_inputs( - seq_len=model_config.max_model_len, - mm_counts=mm_counts, - ) - mm_kwargs = processor.apply( - prompt=processor_inputs.prompt, - mm_data=processor_inputs.mm_data, - hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, - tokenization_kwargs=processor_inputs.tokenization_kwargs, - )["mm_kwargs"] - mm_kwargs = MultiModalKwargs.batch([mm_kwargs]) - return mm_kwargs - - -@pytest.mark.core_model -@pytest.mark.parametrize("model_arch", list(_MULTIMODAL_EXAMPLE_MODELS.keys())) -def test_model_tensor_schema(model_arch: str, vllm_runner: 
type[VllmRunner], - monkeypatch): - if model_arch in ARCH_TO_SKIP: - pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}") - - model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) - model_info.check_available_online(on_fail="skip") - model_info.check_transformers_version(on_fail="skip", - check_max_version=False) - - model_id = model_info.default - - hf_overrides_fn = partial(dummy_hf_overrides, - model_arch=model_arch, - exist_overrides=model_info.hf_overrides) - - model_config = ModelConfig( - model_id, - tokenizer=model_info.tokenizer or model_id, - tokenizer_mode=model_info.tokenizer_mode, - revision=model_info.revision, - trust_remote_code=model_info.trust_remote_code, - hf_overrides=model_info.hf_overrides, - ) - model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) - factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] - - if not any( - hasattr(model_cls, f"_parse_and_validate_{m}_input") - for m in ["image", "video", "audio"]): - pytest.skip(f"{model_arch} does not support tensor schema validation.") - - ctx = InputProcessingContext( - model_config, - tokenizer=cached_tokenizer_from_config(model_config), - ) - processing_info = factories.info(ctx) - supported_mm_limits = processing_info.get_supported_mm_limits() - limit_mm_per_prompt = { - modality: 3 if limit is None else limit - for modality, limit in supported_mm_limits.items() - } - - # Avoid calling model.forward() - def _initialize_kv_caches_v0(self) -> None: - self.cache_config.num_gpu_blocks = 0 - self.cache_config.num_cpu_blocks = 0 - - def _initialize_kv_caches_v1(self, vllm_config): - kv_cache_specs = self.model_executor.get_kv_cache_specs() - scheduler_kv_cache_config = get_kv_cache_config( - vllm_config, - kv_cache_specs[0], - 10 * GiB_bytes, - ) - - # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config - return 1, 0, scheduler_kv_cache_config - - with (patch.object(V0LLMEngine, "_initialize_kv_caches", - _initialize_kv_caches_v0), - 
patch.object(V1EngineCore, "_initialize_kv_caches", - _initialize_kv_caches_v1), monkeypatch.context() as m): - m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - if model_info.v0_only: - m.setenv("VLLM_USE_V1", "0") - - with ( - set_default_torch_num_threads(1), - vllm_runner( - model_id, - tokenizer_name=model_info.tokenizer, - tokenizer_mode=model_info.tokenizer_mode, - revision=model_info.revision, - trust_remote_code=model_info.trust_remote_code, - max_model_len=model_info.max_model_len, - load_format="dummy", - hf_overrides=hf_overrides_fn, - limit_mm_per_prompt=limit_mm_per_prompt, - enforce_eager=True, - ) as vllm_model, - ): - model_config = vllm_model.llm.llm_engine.model_config - llm_engine = vllm_model.llm.llm_engine - - if hasattr(llm_engine, "processor"): - # v1 processor - mm_registry = llm_engine.processor.mm_registry - else: - # v0 input_preprocessor - mm_registry = llm_engine.input_preprocessor.mm_registry - - processor = mm_registry.create_processor(model_config) - mm_kwargs = create_batched_mm_kwargs(model_config, processor) - - def validate_model_input(model): - for modality in ("audio", "image", "video"): - method_name = f"_parse_and_validate_{modality}_input" - if hasattr(model, method_name): - getattr(model, method_name)(**mm_kwargs) - - vllm_model.apply_model(validate_model_input) \ No newline at end of file diff --git a/tests/models/quantization/test_aqlm.py b/tests/models/quantization/test_aqlm.py deleted file mode 100644 index de6851e2fc282..0000000000000 --- a/tests/models/quantization/test_aqlm.py +++ /dev/null @@ -1,68 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - -from tests.quantization.utils import is_quant_method_supported -from vllm.platforms import current_platform - -# These ground truth generations were generated using `transformers==4.38.1 -# aqlm==1.1.0 torch==2.2.0` -# and the below code: -# ```python -# from transformers import 
AutoTokenizer, AutoModelForCausalLM -# model_id = "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf" -# quantized_model = AutoModelForCausalLM.from_pretrained(model_id, -# torch_dtype="auto", device_map="cuda").cuda() -# tokenizer = AutoTokenizer.from_pretrained(model_id) -# outputs = [] -# for prompt in example_prompts: -# input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda") -# hf_outputs = quantized_model.generate(input_ids, max_new_tokens=32) -# outputs.append(tokenizer.decode(hf_outputs[0][input_ids.shape[1]:])) -# print(outputs) -# ``` -ground_truth_generations = [ - '\n### Features\n\n- **High-throughput**: v', - 'The major milestones in the development of artificial intelligence from ' - '195', - 'Compare and contrast artificial intelligence with human intelligence in ' - 'terms of processing information. The', - 'Explain the difference between supervised and unsupervised learning.' - '\nExplain', - 'Write a short story about a robot that dreams for the first time. The', - 'Analyze the impact of the COVID-19 pandemic on global economic', - 'The Mona Lisa is a painting by Leonardo da Vinci, and it', - 'The early bird catches the worm.\nThe early bird catches the' -] - - -@pytest.mark.skipif(not is_quant_method_supported("aqlm") - or current_platform.is_rocm() - or not current_platform.is_cuda(), - reason="AQLM is not supported on this GPU type.") -@pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"]) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [16]) -@pytest.mark.parametrize("num_logprobs", [1]) -def test_models( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, -) -> None: - - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - - # loop through the prompts to compare against the ground truth generations - for prompt_idx in 
range(len(example_prompts)): - vllm_output_ids, vllm_output_str, vllm_logprobs = vllm_outputs[ - prompt_idx] - - print("Prompt: ", repr(example_prompts[prompt_idx])) - print("Reference output:", repr(ground_truth_generations[prompt_idx])) - print("Output output: ", repr(vllm_output_str)) - assert vllm_output_str == ground_truth_generations[prompt_idx] diff --git a/tests/models/quantization/test_fp8.py b/tests/models/quantization/test_fp8.py index 10914abf9ad3d..afc27b6e0566e 100644 --- a/tests/models/quantization/test_fp8.py +++ b/tests/models/quantization/test_fp8.py @@ -32,7 +32,7 @@ from ..utils import check_logprobs_close # Due to low-precision numerical divergence, we only test logprob of 4 tokens @pytest.mark.parametrize("max_tokens", [4]) @pytest.mark.parametrize("enforce_eager", [True]) -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"]) +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"]) # NOTE: Increasing this in this suite will fail CI because we currently cannot # reset distributed env properly. Use a value > 1 just when you test. @pytest.mark.parametrize("tensor_parallel_size", [1]) @@ -57,9 +57,6 @@ def test_models( numerical sensitive kernels. 
""" - if backend == "FLASHINFER" and current_platform.is_rocm(): - pytest.skip("Flashinfer does not support ROCm/HIP.") - if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm(): pytest.skip( f"{kv_cache_dtype} is currently not supported on ROCm/HIP.") diff --git a/tests/models/registry.py b/tests/models/registry.py index 3efc9a99ea415..ee546e7af85c6 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -196,7 +196,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { {"alias": "gpt2"}), "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder", extras={"tiny": "bigcode/tiny_starcoder_py"}, # noqa: E501 - min_transformers_version="4.55.1"), + min_transformers_version="4.55.1", + transformers_version_reason="HF model broken in 4.55.0"), # noqa: E501 "GPTJForCausalLM": _HfExamplesInfo("Milos/slovak-gpt-j-405M", {"6b": "EleutherAI/gpt-j-6b"}), "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-70m", @@ -214,9 +215,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "HunYuanDenseV1ForCausalLM":_HfExamplesInfo("tencent/Hunyuan-7B-Instruct-0124", trust_remote_code=True, is_available_online=False), - "HCXVisionForCausalLM": _HfExamplesInfo( - "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", - trust_remote_code=True), "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b", trust_remote_code=True), "InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b", @@ -232,6 +230,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "tiny": "ai21labs/Jamba-tiny-dev", "random": "ai21labs/Jamba-tiny-random", # noqa: E501 }), + "Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B", + min_transformers_version="4.54"), "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct", extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501 "hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501 @@ -292,13 +292,15 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3MoeForCausalLM": 
_HfExamplesInfo("Qwen/Qwen3-30B-A3B"), "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), + "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501 + trust_remote_code=True, + is_available_online=False), "SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"), "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501 "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"), "Step3TextForCausalLM": _HfExamplesInfo("stepfun-ai/step3", - trust_remote_code=True, - is_available_online=False), + trust_remote_code=True), "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct", trust_remote_code=True), "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B", @@ -315,6 +317,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { # [Encoder-decoder] "BartModel": _HfExamplesInfo("facebook/bart-base"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), + "MBartForConditionalGeneration": _HfExamplesInfo("facebook/mbart-large-en-ro", # noqa: E501 + hf_overrides={"architectures": ["MBartForConditionalGeneration"]}), # noqa: E501 } _EMBEDDING_EXAMPLE_MODELS = { @@ -392,6 +396,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { transformers_version_reason="HF model is not compatible.", # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), + "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501 + trust_remote_code=True), "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501 @@ -402,20 +408,27 @@ _MULTIMODAL_EXAMPLE_MODELS = { hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 
"Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"), # noqa: E501 "Glm4vMoeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V", - is_available_online=False), # noqa: E501 + min_transformers_version="4.56"), # noqa: E501 "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m", trust_remote_code=True, extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501 max_transformers_version="4.48", # noqa: E501 transformers_version_reason="HF model is not compatible."), # noqa: E501 + "HCXVisionForCausalLM": _HfExamplesInfo("naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", # noqa: E501 + trust_remote_code=True), + "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501 + {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}, # noqa: E501 + min_transformers_version="4.56", + transformers_version_reason="HF model broken in 4.55"), # noqa: E501 + "InternS1ForConditionalGeneration": _HfExamplesInfo("internlm/Intern-S1", + trust_remote_code=True), # noqa: E501 "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B", extras={"2B": "OpenGVLab/InternVL2-2B", - "3.0": "OpenGVLab/InternVL3-1B"}, # noqa: E501 + "3.0": "OpenGVLab/InternVL3-1B", # noqa: E501 + "3.5-qwen3": "OpenGVLab/InternVL3_5-1B", # noqa: E501 + "3.5-qwen3moe": "OpenGVLab/InternVL3_5-30B-A3B", # noqa: E501 + "3.5-gptoss": "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview"}, # noqa: E501 trust_remote_code=True), - "InternS1ForConditionalGeneration": _HfExamplesInfo("internlm/Intern-S1", - trust_remote_code=True), - "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501 - {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501 "KeyeForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-8B-Preview", # noqa: E501 trust_remote_code=True), "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501 @@ -438,7 +451,7 @@ 
_MULTIMODAL_EXAMPLE_MODELS = { "MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6", trust_remote_code=True), "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", - extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4"}, # noqa: E501 + extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4", "4.5": "openbmb/MiniCPM-V-4_5"}, # noqa: E501 trust_remote_code=True), "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01", # noqa: E501 trust_remote_code=True, @@ -455,8 +468,12 @@ _MULTIMODAL_EXAMPLE_MODELS = { "Llama_Nemotron_Nano_VL" : _HfExamplesInfo("nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", # noqa: E501 trust_remote_code=True), "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True, + max_transformers_version="4.53", + transformers_version_reason="HF model is not compatible", # noqa: E501 extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B", "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501 + "Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B", + trust_remote_code=True), "PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501 extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501 "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct", @@ -480,12 +497,15 @@ _MULTIMODAL_EXAMPLE_MODELS = { max_model_len=4096), "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"), "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501 + "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", + trust_remote_code=True), "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B", trust_remote_code=True), - "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501 + "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct", # noqa: E501 + min_transformers_version="4.56", + transformers_version_reason="HF model 
broken in 4.55"), # noqa: E501 "Step3VLForConditionalGeneration": _HfExamplesInfo("stepfun-ai/step3", - trust_remote_code=True, - is_available_online=False), + trust_remote_code=True), "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 trust_remote_code=True), "TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b"), # noqa: E501 @@ -498,6 +518,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { is_available_online=False, ), # [Encoder-decoder] + "DonutForConditionalGeneration": _HfExamplesInfo("naver-clova-ix/donut-base-finetuned-docvqa", # noqa: E501 + hf_overrides={"architectures": ["DonutForConditionalGeneration"], "model_type": "donut"}, # noqa: E501 + extras={"dolphin": "ByteDance/Dolphin"}), # noqa: E501 # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Therefore, we borrow the BartTokenizer from the original Bart model "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 @@ -520,6 +543,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random", speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501 trust_remote_code=True), + "EagleDeepSeekMTPModel": _HfExamplesInfo("eagle618/deepseek-v3-random", + speculative_model="eagle618/eagle-deepseek-v3-random", # noqa: E501 + trust_remote_code=True), "EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B", trust_remote_code=True, speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", @@ -543,6 +569,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { is_available_online=False, speculative_model="openbmb/MiniCPM-2B-sft-bf16", tokenizer="openbmb/MiniCPM-2B-sft-bf16"), + "ErnieMTPModel": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT", + trust_remote_code=True, + speculative_model="baidu/ERNIE-4.5-21B-A3B-PT"), "Glm4MoeMTPModel": _HfExamplesInfo("zai-org/GLM-4.5", speculative_model="zai-org/GLM-4.5", 
min_transformers_version="4.54", diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index f06b34285eaea..b4d516233b4bf 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -38,11 +38,6 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, model_arch=model_arch, exist_overrides=model_info.hf_overrides) - if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"): - from vllm.model_executor.models.llama4 import Llama4ForCausalLM - from vllm.model_executor.models.registry import ModelRegistry - ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM) - # Avoid calling model.forward() def _initialize_kv_caches_v0(self) -> None: self.cache_config.num_gpu_blocks = 0 @@ -95,6 +90,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): + if model_arch == "Lfm2ForCausalLM": + pytest.skip("Skipping until test supports V1-only models") can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 3feee01dadf73..77e3732cd06c6 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -255,8 +255,8 @@ async def test_mp_crash_detection(monkeypatch: pytest.MonkeyPatch): pass end = time.perf_counter() - assert end - start < 60, ( - "Expected vLLM to gracefully shutdown in <60s " + assert end - start < 100, ( + "Expected vLLM to gracefully shutdown in <100s " "if there is an error in the startup.") diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py index e07b73bd257d6..44c05db2278f7 100644 --- a/tests/multimodal/test_cache.py +++ b/tests/multimodal/test_cache.py @@ -1,32 +1,64 @@ # SPDX-License-Identifier: Apache-2.0 # 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import numpy as np import pytest import torch -from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata -from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs, - MultiModalKwargsItem, +from vllm.config import ModelConfig, ParallelConfig, VllmConfig +from vllm.multimodal.cache import (MultiModalCache, + MultiModalProcessorCacheItem, + MultiModalProcessorCacheItemMetadata, + processor_cache_from_config, + receiver_cache_from_config) +from vllm.multimodal.hasher import MultiModalHasher +from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem, + MultiModalKwargsItems, MultiModalSharedField) +from vllm.multimodal.processing import PromptInsertion +from vllm.multimodal.registry import MultiModalRegistry -def _dummy_elem(modality: str, key: str, size: int): +def _dummy_elem( + modality: str, + key: str, + size: int, + *, + rng: Optional[np.random.RandomState] = None, +): + if rng is None: + data = torch.empty((size, ), dtype=torch.int8) + else: + data = torch.from_numpy(rng.randint(4, size=(size, ), dtype=np.int8)) + return MultiModalFieldElem( modality=modality, key=key, - data=torch.empty((size, ), dtype=torch.int8), + data=data, field=MultiModalSharedField(1), ) -def _dummy_item(modality: str, size_by_key: dict[str, int]): +def _dummy_item( + modality: str, + size_by_key: dict[str, int], + *, + rng: Optional[np.random.RandomState] = None, +): return MultiModalKwargsItem.from_elems([ - _dummy_elem(modality, key, size) for key, size in size_by_key.items() + _dummy_elem(modality, key, size, rng=rng) + for key, size in size_by_key.items() ]) -def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]): - return MultiModalKwargs.from_items([ - _dummy_item(modality, size_by_key) +def _dummy_items( + size_by_key_modality: dict[str, dict[str, int]], + *, + rng: Optional[np.random.RandomState] = None, +): + return 
MultiModalKwargsItems.from_seq([ + _dummy_item(modality, size_by_key, rng=rng) for modality, size_by_key in size_by_key_modality.items() ]) @@ -37,7 +69,8 @@ def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]): [ (_dummy_item("a", {"a1": 100}), 100), (_dummy_item("a", {"a1": 100, "a2": 110}), 210), - (_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501 + (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501 + (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}).get_data(), 460), # noqa: E501 ], ) # yapf: enable @@ -47,5 +80,139 @@ def test_cache_item_size(item, expected_size): cache[""] = item assert cache.currsize == expected_size - cache[""] = MultiModalCacheItemMetadata.wraps(item) + prompt_update = PromptInsertion("dummy", "target", "insertion") \ + .resolve(0) + + cache[""] = MultiModalProcessorCacheItem(item, [prompt_update]) assert cache.currsize == expected_size + + cache[""] = MultiModalProcessorCacheItemMetadata(item, [prompt_update]) + assert cache.currsize == expected_size + + +def _create_vllm_config( + *, + mm_processor_cache_gb: float, + enable_ipc: bool, +): + return VllmConfig( + model_config=ModelConfig(mm_processor_cache_gb=mm_processor_cache_gb), + parallel_config=ParallelConfig( + data_parallel_size=1 if enable_ipc else 2), + ) + + +def _compare_caches( + config_0: VllmConfig, + config_1: VllmConfig, + *, + item_capacity: int = 8, + hit_rate: float = 0.5, + max_items_per_iter: int = 3, + is_cached_calls_per_iter: int, + n_iter: int = 100, + seed: int = 0, +): + mm_registry = MultiModalRegistry() + cache_0_p0 = processor_cache_from_config(config_0, mm_registry) + cache_0_p1 = receiver_cache_from_config(config_0, mm_registry) + cache_1_p0 = processor_cache_from_config(config_1, mm_registry) + cache_1_p1 = receiver_cache_from_config(config_1, mm_registry) + + cache_size_gb = max( + config_0.model_config.mm_processor_cache_gb, + 
config_1.model_config.mm_processor_cache_gb, + ) + item_size_gb = int(cache_size_gb / item_capacity) + + rng = np.random.RandomState(seed) + all_items = [ + _dummy_item("item", {"key": item_size_gb}, rng=rng) + for _ in range(int(item_capacity / hit_rate)) + ] + all_hashes = [ + MultiModalHasher.hash_kwargs(item=item.get_data()) + for item in all_items + ] + + # Should not be used since there is nothing to convert to text + prompt_update = PromptInsertion("dummy", "target", "insertion") + + for it in range(n_iter): + num_items_to_select = rng.randint(0, max_items_per_iter) + item_idxs_to_select = rng.choice(len(all_items), num_items_to_select) + + selected_items = [all_items[idx] for idx in item_idxs_to_select] + selected_hashes = [all_hashes[idx] for idx in item_idxs_to_select] + + if cache_0_p0 is None: + cache_0_p0_out = selected_items + else: + for _ in range(is_cached_calls_per_iter): + cache_0_p0.is_cached(selected_hashes) + cache_0_p0_out = [ + item for item, _ in cache_0_p0.get_and_update( + [(item, prompt_update.content) for item in selected_items], + selected_hashes, + ) + ] + + if cache_1_p0 is None: + cache_1_p0_out = selected_items + else: + for _ in range(is_cached_calls_per_iter): + cache_1_p0.is_cached(selected_hashes) + cache_1_p0_out = [ + item for item, _ in cache_1_p0.get_and_update( + [(item, prompt_update.content) for item in selected_items], + selected_hashes, + ) + ] + + if cache_0_p1 is None: + cache_0_p1_out = cache_0_p0_out + else: + cache_0_p1_out = cache_0_p1.get_and_update(cache_0_p0_out, + selected_hashes) + + if cache_1_p1 is None: + cache_1_p1_out = cache_1_p0_out + else: + cache_1_p1_out = cache_1_p1.get_and_update(cache_1_p0_out, + selected_hashes) + + assert cache_0_p1_out == cache_1_p1_out, f"Failed at {it=}" + + +@pytest.mark.parametrize("is_cached_calls_per_iter", [1, 2, 3]) +def test_ipc_enable_disable_consistency(is_cached_calls_per_iter): + cache_size_gb = 1 / (1 << 20) + + vllm_config_ipc_enabled = _create_vllm_config( + 
mm_processor_cache_gb=cache_size_gb, + enable_ipc=True, + ) + vllm_config_ipc_disabled = _create_vllm_config( + mm_processor_cache_gb=0, + enable_ipc=False, + ) + vllm_config_cache_disabled = _create_vllm_config( + mm_processor_cache_gb=cache_size_gb, + enable_ipc=True, + ) + + _compare_caches( + vllm_config_ipc_enabled, + vllm_config_ipc_disabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) + _compare_caches( + vllm_config_ipc_disabled, + vllm_config_cache_disabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) + _compare_caches( + vllm_config_cache_disabled, + vllm_config_ipc_enabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) diff --git a/tests/multimodal/test_hasher.py b/tests/multimodal/test_hasher.py index 75a233c2567cb..2751e38760e17 100644 --- a/tests/multimodal/test_hasher.py +++ b/tests/multimodal/test_hasher.py @@ -45,10 +45,11 @@ def test_hash_collision_image_transpose(): assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2) -def test_hash_collision_tensor_shape(): +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_hash_collision_tensor_shape(dtype): # The hash should be different though the data is the same when flattened - arr1 = torch.zeros((5, 10, 20, 3)) - arr2 = torch.zeros((10, 20, 5, 3)) + arr1 = torch.zeros((5, 10, 20, 3), dtype=dtype) + arr2 = torch.zeros((10, 20, 5, 3), dtype=dtype) hasher = MultiModalHasher assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2) diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index cb489c47fd8fd..6ce5fcfe644bd 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -17,13 +17,11 @@ from vllm.multimodal.processing import (PlaceholderFeaturesInfo, PromptReplacement, apply_text_matches, apply_token_matches, find_mm_placeholders, - find_text_matches, find_token_matches, iter_token_matches, replace_token_matches) # yapf: enable from 
vllm.multimodal.profiling import MultiModalProfiler from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import full_groupby from .utils import random_image @@ -75,12 +73,15 @@ from .utils import random_image ), ], ) +@pytest.mark.parametrize("start_idx", [0, 4, 8]) # yapf: enable -def test_iter_token_matches(token_ids, match_ids, expected): - result = list(iter_token_matches(token_ids, match_ids)) +def test_iter_token_matches(token_ids, match_ids, expected, start_idx): + result = list(iter_token_matches(token_ids, match_ids, + start_idx=start_idx)) # Manually constructed results - assert [item._asdict() for item in result] == expected + assert [item._asdict() for item in result + ] == [item for item in expected if item["start_idx"] >= start_idx] # Invariants match_lens = [end - start for start, end in result] @@ -241,21 +242,23 @@ def test_find_token_matches( # Should not be used since there is nothing to convert to token IDs mock_tokenizer = cast(AnyTokenizer, object()) - prompt_updates = [ - update_type(key, target, []).bind(mock_tokenizer) + prompt_updates = { + key: update_type(key, target, []).resolve(0) for key, target in target_by_key.items() - ] - result = find_token_matches(prompt, prompt_updates) + } + result = { + key: list(update.iter_token_matches(prompt, mock_tokenizer)) + for key, update in prompt_updates.items() + } # Only displayed on error print("result:", result) # Manually constructed results - result_groups = dict(full_groupby(result, key=lambda x: x.modality)) assert { key: [ dict(start_idx=item.start_idx, end_idx=item.end_idx) - for item in result_groups.get(key, []) + for item in result.get(key, []) ] for key in expected_by_key } == expected_by_key @@ -388,21 +391,23 @@ def test_find_text_matches( # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - prompt_updates = [ - update_type(key, target, []).bind(mock_tokenizer) + prompt_updates = { + key: 
update_type(key, target, []).resolve(0) for key, target in target_by_key.items() - ] - result = find_text_matches(prompt, prompt_updates) + } + result = { + key: list(update.iter_text_matches(prompt, mock_tokenizer)) + for key, update in prompt_updates.items() + } # Only displayed on error print("result:", result) # Manually constructed results - result_groups = dict(full_groupby(result, key=lambda x: x.modality)) assert { key: [ dict(start_idx=item.start_idx, end_idx=item.end_idx) - for item in result_groups.get(key, []) + for item in result.get(key, []) ] for key in expected_by_key } == expected_by_key @@ -552,39 +557,35 @@ def test_find_update_text( update_type, expected_by_mm_count, ) in expected_by_update_type_mm_count.items(): - mm_prompt_updates = { - key: - [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)] - for key, target in target_by_key.items() - } - mm_matches = { - key: find_text_matches(prompt, updates) - for key, updates in mm_prompt_updates.items() - } - for mm_count, expected in expected_by_mm_count.items(): - result = apply_text_matches( + mm_prompt_updates = { + key: [[update_type(key, target, repl_by_key[key]).resolve(i)] + for i in range(mm_count)] + for key, target in target_by_key.items() + } + + new_prompt, result = apply_text_matches( prompt, - mm_matches, - {key: mm_count - for key in repl_by_key}, + mm_prompt_updates, + mock_tokenizer, ) # Only displayed on error print("update_type:", update_type) print("mm_count:", mm_count) - print("mm_matches:", mm_matches) + print("mm_prompt_updates:", mm_prompt_updates) + print("new_prompt:", new_prompt) print("result:", result) # Manually constructed results - assert result == expected + assert new_prompt == expected # yapf: disable @pytest.mark.parametrize( ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501 [ - # Tokenized test cases of `test_find_replace_text` + # Tokenized test cases of `test_find_update_text` # using the vocab of 
llava-hf/llava-v1.6-mistral-7b-hf ( [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], @@ -726,32 +727,28 @@ def test_find_update_tokens( update_type, expected_by_mm_count, ) in expected_by_update_type_mm_count.items(): - mm_prompt_updates = { - key: - [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)] - for key, target in target_by_key.items() - } - mm_matches = { - key: find_token_matches(prompt, updates) - for key, updates in mm_prompt_updates.items() - } - for mm_count, expected in expected_by_mm_count.items(): - result = apply_token_matches( + mm_prompt_updates = { + key: [[update_type(key, target, repl_by_key[key]).resolve(i)] + for i in range(mm_count)] + for key, target in target_by_key.items() + } + + new_prompt, result = apply_token_matches( prompt, - mm_matches, - {key: mm_count - for key in repl_by_key}, + mm_prompt_updates, + mock_tokenizer, ) # Only displayed on error print("update_type:", update_type) print("mm_count:", mm_count) - print("mm_matches:", mm_matches) + print("mm_prompt_updates:", mm_prompt_updates) + print("new_prompt:", new_prompt) print("result:", result) # Manually constructed results - assert result == expected + assert new_prompt == expected # yapf: disable @@ -878,17 +875,11 @@ def test_find_mm_placeholders( mock_tokenizer = cast(AnyTokenizer, object()) mm_prompt_updates = { - key: [update_type(key, [], repl).bind(mock_tokenizer)] + key: [[update_type(key, [], repl).resolve(i)] for i in range(3)] for key, repl in repl_by_key.items() } - result = find_mm_placeholders( - mm_prompt_updates, - prompt, - # Effectively match all occurrences in the prompt - {key: 3 - for key in repl_by_key}, - ) + result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer) # Only displayed on error print("result:", result) diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index ea964a54383c9..a028c668c8ab7 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -2,6 
+2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 +import math import mimetypes import os from tempfile import NamedTemporaryFile, TemporaryDirectory @@ -20,6 +21,8 @@ from vllm.distributed.parallel_state import (init_distributed_environment, from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions, + get_load_balance_assignment, + run_dp_sharded_mrope_vision_model, run_dp_sharded_vision_model) from vllm.platforms import current_platform from vllm.utils import get_open_port, update_environment_variables @@ -425,8 +428,8 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, # Set random seed for reproducibility current_platform.seed_everything(0) - device = torch.device(f"cuda:{local_rank}") - torch.cuda.set_device(device) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) torch.set_default_device(device) update_environment_variables({ @@ -463,3 +466,322 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, # Check that the outputs are close (they should be identical) assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize( + "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts," + "expected_grouped_sizes_per_gpu,test_description", + [ + # Empty input + ([], 2, [], [0, 0], [0, 0], "empty input"), + + # Fewer samples than GPUs + ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0 + ], "fewer samples than GPUs"), + + # Single GPU + ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"), + + # Balanced assignment + ([100, 100, 100, 100 + ], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"), + + # Unbalanced sizes - this one is trickier since the algorithm is greedy + ([1000, 100, 200, 50], 2, [0, 2, 1, 3 + ], [1, 3], [1000, 350], 
"unbalanced sizes"), + ], +) +def test_get_load_balance_assignment_cases(sizes, num_gpus, + expected_shuffle_indices, + expected_gpu_sample_counts, + expected_grouped_sizes_per_gpu, + test_description): + """Test get_load_balance_assignment with various input cases.""" + result = get_load_balance_assignment(sizes, num_gpus=num_gpus) + (shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result + + # Common assertions for all cases + assert len(shuffle_indices) == len(sizes) + assert len(gpu_sample_counts) == num_gpus + assert len(grouped_sizes_per_gpu) == num_gpus + assert sum(gpu_sample_counts) == len(sizes) + + assert shuffle_indices == expected_shuffle_indices + + assert gpu_sample_counts == expected_gpu_sample_counts + assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu + + +class SimpleMRopeVisionModel(torch.nn.Module): + """A simple vision model for testing mrope functionality.""" + + def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64): + super().__init__() + self.spatial_merge_size = spatial_merge_size + self.out_hidden_size = out_hidden_size + self.linear = torch.nn.Linear(768, out_hidden_size) + + def forward(self, pixel_values: torch.Tensor, + grid_thw_list: list[list[int]]): + """Simple forward pass that simulates spatial merging.""" + # Apply linear transformation + embeddings = self.linear(pixel_values) + + # Simulate spatial merging by reducing the number of patches + merge_factor = self.spatial_merge_size * self.spatial_merge_size + + # Group patches and merge spatially + merged_embeddings = [] + start_idx = 0 + + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + end_idx = start_idx + num_patches + + # Get patches for this image + image_patches = embeddings[start_idx:end_idx] + + # Simulate spatial merging by averaging groups of patches + merged_patches = num_patches // merge_factor + if merged_patches > 0: + # Reshape and average to simulate merging + reshaped = 
image_patches[:merged_patches * merge_factor].view( + merged_patches, merge_factor, -1) + merged = reshaped.mean(dim=1) + merged_embeddings.append(merged) + + start_idx = end_idx + + if merged_embeddings: + return torch.cat(merged_embeddings, dim=0) + else: + return torch.empty((0, self.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "batch_size", + [ + 1, # Single image + 3, # Small batch + 5, # Odd batch size (for testing padding) + ], +) +def test_run_dp_sharded_mrope_vision_model(batch_size: int): + world_size = 2 + # Launch processes + mp.spawn( + run_dp_sharded_mrope_vision_model_vs_direct, + args=( + world_size, + batch_size, + get_open_port(), + ), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int, + world_size: int, + batch_size: int, + master_port: int): + """ + Test that run_dp_sharded_mrope_vision_model produces the same results as + calling the model directly. 
+ """ + # Set random seed for reproducibility + current_platform.seed_everything(0) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create test data + grid_thw_list = [] + pixel_values_list = [] + + for i in range(batch_size): + # Varying image sizes for better testing + t, h, w = 1, 4 + i, 4 + i + grid_thw_list.append([t, h, w]) + + num_patches = t * h * w + # Create random pixel values for this image + image_pixels = torch.randn(num_patches, 768) + pixel_values_list.append(image_pixels) + + # Concatenate all pixel values + pixel_values = torch.cat(pixel_values_list, dim=0) + + # Create a simple mrope vision model + vision_model = SimpleMRopeVisionModel() + + # Run the model directly on the full input (only on rank 0) + if local_rank == 0: + with torch.inference_mode(): + direct_output = vision_model(pixel_values, grid_thw_list) + + # Run the model through the sharded function + with torch.inference_mode(): + sharded_output = run_dp_sharded_mrope_vision_model( + vision_model, pixel_values, grid_thw_list) + sharded_output = torch.cat(sharded_output, dim=0) + + # Check that the world size is setup correctly + assert get_tensor_model_parallel_world_size() == world_size + + # Compare outputs (only on rank 0) + if local_rank == 0: + # Check that the outputs have the same shape + assert direct_output.shape == sharded_output.shape + # Check that the outputs are close (they should be identical) + assert torch.allclose(direct_output, + sharded_output, + rtol=1e-5, + atol=1e-5) + + +@multi_gpu_test(num_gpus=2) +def 
test_run_dp_sharded_mrope_vision_model_empty_input(): + world_size = 2 + mp.spawn( + run_dp_sharded_mrope_vision_model_empty_input_worker, + args=(world_size, get_open_port()), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_empty_input_worker( + local_rank: int, world_size: int, master_port: int): + """Test run_dp_sharded_mrope_vision_model with empty input.""" + # Set up distributed environment + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create empty inputs + pixel_values = torch.empty((0, 768)) + grid_thw_list: list[list[int]] = [] + + vision_model = SimpleMRopeVisionModel() + + # Should handle empty input gracefully + with torch.inference_mode(): + output = run_dp_sharded_mrope_vision_model(vision_model, pixel_values, + grid_thw_list) + + assert len(output) == 0 + + +@multi_gpu_test(num_gpus=4) +def test_run_dp_sharded_mrope_vision_model_uneven_load(): + world_size = 4 + mp.spawn( + run_dp_sharded_mrope_vision_model_uneven_load_worker, + args=(world_size, get_open_port()), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_uneven_load_worker( + local_rank: int, world_size: int, master_port: int): + """Test run_dp_sharded_mrope_vision_model with uneven load distribution.""" + # Set up distributed environment + current_platform.seed_everything(123) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 
'MASTER_PORT': str(master_port), + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create images with very different sizes + grid_thw_list = [ + [1, 2, 2], # Small: 4 patches + [1, 8, 8], # Large: 64 patches + [1, 3, 3], # Medium: 9 patches + ] + + pixel_values_list = [] + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + image_pixels = torch.randn(num_patches, 768) + pixel_values_list.append(image_pixels) + + pixel_values = torch.cat(pixel_values_list, dim=0) + vision_model = SimpleMRopeVisionModel() + + # Should handle uneven distribution without errors + with torch.inference_mode(): + output_tuple = run_dp_sharded_mrope_vision_model( + vision_model, pixel_values, grid_thw_list) + + # Verify output shape is reasonable + merge_factor = vision_model.spatial_merge_size**2 + expected_output_patches = list( + math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list) + + for i, output in enumerate(output_tuple): + assert output.shape[0] == expected_output_patches[i] + assert output.shape[1] == vision_model.out_hidden_size + + +@pytest.mark.parametrize("spatial_merge_size", [2, 4]) +def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int): + """Test SimpleMRopeVisionModel with different spatial merge sizes.""" + device = current_platform.device_type + + grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images + pixel_values_list = [] + + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + image_pixels = torch.randn(num_patches, 768, device=device) + pixel_values_list.append(image_pixels) + + pixel_values = torch.cat(pixel_values_list, dim=0) + vision_model = SimpleMRopeVisionModel( + spatial_merge_size=spatial_merge_size).to(device) + + with torch.inference_mode(): + output = vision_model(pixel_values, grid_thw_list) + + # Verify output dimensions based on spatial merging + total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list) + 
merge_factor = spatial_merge_size**2 + expected_output_patches = total_patches // merge_factor + + assert output.shape[0] == expected_output_patches + assert output.shape[1] == vision_model.out_hidden_size diff --git a/tests/prefix_caching/test_disable_sliding_window.py b/tests/prefix_caching/test_disable_sliding_window.py deleted file mode 100644 index b940ab416e673..0000000000000 --- a/tests/prefix_caching/test_disable_sliding_window.py +++ /dev/null @@ -1,49 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Compare the with and without prefix caching. - -Run `pytest tests/prefix_caching/test_prefix_caching.py`. -""" -import pytest - -from vllm import LLM -from vllm.distributed import cleanup_dist_env_and_memory - -MODEL_LEN_LEN = [ - # Example models with sliding window. - ("bigcode/starcoder2-3b", 4096, 16384), - # ("mistralai/Mistral-7B-v0.1", 4096, 32768), << OOM in CI - - # Confirm model with sliding window works. - # config has "use_sliding_window": false - ("Qwen/Qwen1.5-0.5B-Chat", 32768, 32768), - # config has no sliding window attribute. 
- ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", 2048, 2048), -] - - -@pytest.mark.parametrize("model_len_len", MODEL_LEN_LEN) -def test_disable_sliding_window(model_len_len, ): - model, sliding_len, full_len = model_len_len - disabled_llm = LLM(model, disable_sliding_window=True) - disabled_llm.generate("Hi my name is") - model_config = disabled_llm.llm_engine.model_config - assert model_config.max_model_len == sliding_len, ( - "Max len expected to equal sliding_len of %s, but got %s", sliding_len, - model_config.max_model_len) - - del disabled_llm - cleanup_dist_env_and_memory() - - enabled_llm = LLM(model, - enforce_eager=True, - disable_sliding_window=False, - enable_prefix_caching=False) - enabled_llm.generate("Hi my name is") - model_config = enabled_llm.llm_engine.model_config - assert model_config.max_model_len == full_len, ( - "Max len expected to equal full_len of %s, but got %s", full_len, - model_config.max_model_len) - - del enabled_llm - cleanup_dist_env_and_memory() diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py deleted file mode 100644 index 5bf6ed957c74e..0000000000000 --- a/tests/prefix_caching/test_prefix_caching.py +++ /dev/null @@ -1,231 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Compare the with and without prefix caching. - -Run `pytest tests/prefix_caching/test_prefix_caching.py`. 
-""" - -from __future__ import annotations - -import pytest - -from tests.conftest import VllmRunner -from tests.core.utils import SchedulerProxy, create_dummy_prompt -from vllm import SamplingParams, TokensPrompt -from vllm.core.scheduler import Scheduler -from vllm.engine.llm_engine import LLMEngine -from vllm.platforms import current_platform -from vllm.utils import STR_BACKEND_ENV_VAR - -from ..models.utils import check_outputs_equal - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch: pytest.MonkeyPatch): - """ - This module relies on V0 internals, so set VLLM_USE_V1=0. - """ - with monkeypatch.context() as m: - m.setenv('VLLM_USE_V1', '0') - yield - - -MODELS = [ - "distilbert/distilgpt2", -] - -UNSTABLE_PROMPT_SEQUENCE = [ - ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([3] * 1), - ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([5] * 50), - ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([6] * 95), - ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([7] * 174), - ([0] * 588) + ([8] * 1539), -] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("cached_position", [0, 1]) -@pytest.mark.parametrize("enable_chunked_prefill", [True, False]) -@pytest.mark.parametrize("block_size", [16]) -def test_mixed_requests( - hf_runner, - vllm_runner, - example_prompts, - model: str, - backend: str, - dtype: str, - max_tokens: int, - cached_position: int, - enable_chunked_prefill: bool, - block_size: int, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """ - Test the case when some sequences have the prefix cache hit - and the others don't. The cached position determines where - the sequence is at among the batch of prefills. 
- """ - if backend == "FLASHINFER" and current_platform.is_rocm(): - pytest.skip("Flashinfer does not support ROCm/HIP.") - if backend == "XFORMERS" and current_platform.is_rocm(): - pytest.skip("Xformers does not support ROCm/HIP.") - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, backend) - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - cached_prompt = example_prompts[cached_position] - with vllm_runner( - model, - dtype=dtype, - enable_prefix_caching=True, - enable_chunked_prefill=enable_chunked_prefill, - block_size=block_size, - ) as vllm_model: - # Run the first prompt so the cache is populated - vllm_outputs = vllm_model.generate_greedy([cached_prompt], - max_tokens) - - # Run all the promopts - greedy_params = SamplingParams(temperature=0.0, - max_tokens=max_tokens) - req_outputs = vllm_model.llm.generate(example_prompts, - greedy_params) - - # Verify number of cached tokens - for i in range(len(req_outputs)): - if i == cached_position: - expected_num_cached_tokens = ( - len(req_outputs[i].prompt_token_ids) // - block_size) * block_size - else: - expected_num_cached_tokens = 0 - assert (req_outputs[i].num_cached_tokens == - expected_num_cached_tokens) - - vllm_outputs = [( - output.prompt_token_ids + list(output.outputs[0].token_ids), - output.prompt + output.outputs[0].text, - ) for output in req_outputs] - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) -def test_unstable_prompt_sequence( - vllm_runner, - backend: str, - monkeypatch: pytest.MonkeyPatch, -) -> None: - - if backend == "FLASHINFER" and current_platform.is_rocm(): - pytest.skip("Flashinfer does not support ROCm/HIP.") - if backend == "XFORMERS" and current_platform.is_rocm(): - pytest.skip("Xformers does not support ROCm/HIP.") - with 
monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, backend) - - with vllm_runner( - "Qwen/Qwen2.5-0.5B-Instruct", - enable_chunked_prefill=True, - enable_prefix_caching=True, - max_model_len=4096, - ) as vllm_model: - for prompt in UNSTABLE_PROMPT_SEQUENCE: - vllm_model.generate(TokensPrompt(prompt_token_ids=prompt), - SamplingParams(max_tokens=1)) - - -@pytest.mark.parametrize("model", MODELS) -def test_fully_cached_prefill_needs_uncached_token(model): - block_size = 16 - max_num_batched_tokens = 16 - num_output_tokens = 5 - # Make a vllm engine - runner = VllmRunner( - model_name=model, - gpu_memory_utilization=0.7, - enable_chunked_prefill=True, - enforce_eager=True, - enable_prefix_caching=True, - block_size=block_size, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_batched_tokens, - ) - engine: LLMEngine = runner.llm.llm_engine - - scheduler: Scheduler = SchedulerProxy(engine.scheduler[0]) # type: ignore - engine.scheduler[0] = scheduler - - # SeqA - seqA_tokens = list(range(2 * block_size)) - seqA, seq_groupA = create_dummy_prompt( - request_id="0", - prompt_tokens=seqA_tokens, - max_tokens=num_output_tokens, - block_size=block_size, - ) - - scheduler.add_seq_group(seq_groupA) - - assert seqA.data.get_num_computed_tokens() == 0 - - # Prefill seqA - while not seqA.is_finished(): - engine.step() - - # seqB - seqB_tokens = [t + 1 for t in seqA_tokens] # shift by 1 - seqB, seq_groupB = create_dummy_prompt( - request_id="1", - prompt_tokens=seqB_tokens, - max_tokens=num_output_tokens, - block_size=block_size, - ) - - # seqC is the same as seqA - seqC, seq_groupC = create_dummy_prompt( - request_id="2", - prompt_tokens=seqA_tokens, - max_tokens=num_output_tokens, - block_size=block_size, - ) - - scheduler.add_seq_group(seq_groupB) - scheduler.add_seq_group(seq_groupC) - - # Even seqC is fully cached, it should not be prefilled since we - # require at least 1 uncached token. 
- engine.step() - - sched_metas, sched_out, _ = scheduler.last_schedule_ret() - assert len(sched_out.scheduled_seq_groups) == 1 - assert (sched_out.scheduled_seq_groups[0].seq_group.request_id == - seq_groupB.request_id) - assert (sched_out.scheduled_seq_groups[0].token_chunk_size == - max_num_batched_tokens) - - # When seqB is finished, seqC could be prefilled. - while not seqB.is_finished(): - engine.step() - sched_metas, sched_out, _ = scheduler.last_schedule_ret() - assert len(sched_out.scheduled_seq_groups) == 1 - assert (sched_out.scheduled_seq_groups[0].seq_group.request_id == - seq_groupB.request_id) - - engine.step() - sched_metas, sched_out, _ = scheduler.last_schedule_ret() - assert len(sched_out.scheduled_seq_groups) == 1 - assert (sched_out.scheduled_seq_groups[0].seq_group.request_id == - seq_groupC.request_id) - assert sched_out.scheduled_seq_groups[0].token_chunk_size == len( - seqA_tokens) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 296743dbfa041..b9774b7ee2631 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -14,10 +14,10 @@ from compressed_tensors.quantization import QuantizationType from tests.models.utils import check_logprobs_close from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensors24, CompressedTensorsLinearMethod, - CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4, - CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, - CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16) + CompressedTensorsW4A4Fp4, CompressedTensorsW4A8Fp8, + CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.utils.quant_utils import ( cutlass_fp4_supported) from 
vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -683,3 +683,39 @@ def test_compressed_tensors_nvfp4(vllm_runner, args): output = llm.generate_greedy("Hello my name is", max_tokens=20) print(output) assert output + + +@pytest.mark.skipif( + not current_platform.is_cuda() + or not current_platform.has_device_capability(90), + reason="W4A8 FP8 is not yet supported on this GPU type.", +) +@pytest.mark.parametrize("args", [ + ("czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e", CompressedTensorsW4A8Fp8) +]) +def test_compressed_tensors_w4a8_fp8(vllm_runner, args): + model, scheme = args + with vllm_runner(model, enforce_eager=True) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + o_proj = layer.self_attn.o_proj + gate_up_proj = layer.mlp.gate_up_proj + down_proj = layer.mlp.down_proj + + for proj in (qkv_proj, o_proj, gate_up_proj, down_proj): + assert isinstance(proj.quant_method, + CompressedTensorsLinearMethod) + assert isinstance(proj.scheme, scheme) + + assert proj.weight_packed.dtype is torch.int32 + assert proj.weight_scale.dtype is torch.float8_e4m3fn + assert proj.weight_chan_scale.dtype is torch.float32 + assert proj.scheme.group_size == 128 + + llm.apply_model(check_model) + output = llm.generate_greedy("Hello my name is", max_tokens=20) + print(output) + assert output diff --git a/tests/quantization/test_configs.py b/tests/quantization/test_configs.py index 8cf8402436ff5..1843bffd21159 100644 --- a/tests/quantization/test_configs.py +++ b/tests/quantization/test_configs.py @@ -22,22 +22,12 @@ class ModelPair: MODEL_ARG_EXPTYPES = [ # AUTOGPTQ # compat: autogptq <=0.7.1 is_marlin_format: bool - # Model Serialized in Marlin Format should always use Marlin kernel. 
- ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", None, "marlin"), - ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin", "marlin"), - ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "gptq", "marlin"), - ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "awq", "ERROR"), # Model Serialized in Exllama Format. ("TheBloke/Llama-2-7B-Chat-GPTQ", None, "gptq_marlin"), ("TheBloke/Llama-2-7B-Chat-GPTQ", "marlin", "gptq_marlin"), ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq", "gptq"), ("TheBloke/Llama-2-7B-Chat-GPTQ", "awq", "ERROR"), # compat: autogptq >=0.8.0 use checkpoint_format: str - # Model Serialized in Marlin Format should always use Marlin kernel. - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", None, "marlin"), - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin", "marlin"), - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "gptq", "marlin"), - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "awq", "ERROR"), # Model Serialized in Exllama Format. ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", None, "gptq_marlin"), ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "marlin", "gptq_marlin"), diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 0b37c83c92c2a..d781f462b4ad7 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -38,8 +38,7 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, with vllm_runner(model_id) as llm: # note: this does not test accuracy, just that we can run through # see lm-eval tests for accuracy - outputs = llm.generate_greedy(prompts=["Hello my name is"], - max_tokens=10) + outputs = llm.generate_greedy(["Hello my name is"], max_tokens=10) print(outputs[0][1]) @@ -90,8 +89,7 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, # note: this does not test accuracy, just that we can run through # see lm-eval tests for accuracy - outputs = llm.generate_greedy(prompts=["Hello my name is"], - max_tokens=10) + outputs = 
llm.generate_greedy(["Hello my name is"], max_tokens=10) print(outputs[0][1]) diff --git a/tests/quantization/test_lm_head.py b/tests/quantization/test_lm_head.py index 11f78a23bb4c0..b24964a9d0a9f 100644 --- a/tests/quantization/test_lm_head.py +++ b/tests/quantization/test_lm_head.py @@ -11,7 +11,6 @@ import torch from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinLinearMethod) -from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod from vllm.model_executor.layers.vocab_parallel_embedding import ( UnquantizedEmbeddingMethod) @@ -19,9 +18,7 @@ PROMPT = "On the surface of Mars, we found" MODELS_QUANT = [ ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head", True), - ("ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024", False), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False), - ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False) ] @@ -41,8 +38,7 @@ def test_lm_head( lm_head_layer = model.lm_head if lm_head_quantized: assert isinstance(lm_head_layer.quant_method, - (GPTQLinearMethod, GPTQMarlinLinearMethod, - MarlinLinearMethod)) + (GPTQLinearMethod, GPTQMarlinLinearMethod)) else: assert isinstance(lm_head_layer.quant_method, UnquantizedEmbeddingMethod) @@ -50,5 +46,5 @@ def test_lm_head( vllm_model.apply_model(check_model) print( - vllm_model.generate_greedy(prompts=["Hello my name is"], + vllm_model.generate_greedy(["Hello my name is"], max_tokens=10)[0][1]) diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index bdf48c7687b25..cc9a88a255f9f 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -67,6 +67,59 @@ def test_beam_search_single_input( f"vLLM: {vllm_output_ids}") +@pytest.mark.skip_v1 # FIXME: This fails on V1 right now. 
+@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", MAX_TOKENS) +@pytest.mark.parametrize("beam_width", BEAM_WIDTHS) +def test_beam_search_with_concurrency_limit( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + beam_width: int, +) -> None: + # example_prompts[1]&[3]&[7] fails due to unknown reason even without + # concurency limit. skip them for now. + example_prompts = (example_prompts[:8]) + concurrency_limit = 2 + assert len(example_prompts) > concurrency_limit + with vllm_runner(model, dtype=dtype) as vllm_model: + outputs_with_limit = vllm_model.generate_beam_search( + example_prompts, + beam_width, + max_tokens, + concurrency_limit=concurrency_limit) + outputs_without_limit = [] + + for i in range(0, len(example_prompts), concurrency_limit): + outputs_without_limit.extend( + vllm_model.generate_beam_search( + example_prompts[i:i + concurrency_limit], beam_width, + max_tokens)) + + correct = True + for i in range(len(example_prompts)): + output_ids_with_limit, output_texts_with_limit = outputs_with_limit[i] + output_ids_without_limit, output_texts_without_limit = ( + outputs_without_limit[i]) + for j, (text_with_limit, text_without_limit) in enumerate( + zip(output_texts_with_limit, output_texts_without_limit)): + print(f">>>{j}-th with limit output:") + print(text_with_limit) + print(f">>>{j}-th without limit output:") + print(text_without_limit) + assert len(output_ids_with_limit) == len(output_ids_without_limit) + for j in range(len(output_ids_with_limit)): + if output_ids_with_limit[j] != output_ids_without_limit[j]: + print(f"Test{i} output{j}:\n+limit: {output_ids_with_limit}\n" + f"-limit: {output_ids_without_limit}") + correct = False + assert correct + + @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", MAX_TOKENS) @pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS) diff --git 
a/tests/test_sequence.py b/tests/test_sequence.py index c734c8514a6da..1b019be9e56dc 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest +import torch from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData, - SequenceOutput) +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + SequenceData, SequenceOutput) from .core.utils import create_dummy_prompt @@ -98,3 +99,38 @@ def test_sequence_group_stage(): assert seq_group.is_prefill() is True seq_group.update_num_computed_tokens(1) assert seq_group.is_prefill() is False + + +def test_sequence_intermediate_tensors_equal(): + + class AnotherIntermediateTensors(IntermediateTensors): + pass + + intermediate_tensors = IntermediateTensors({}) + another_intermediate_tensors = AnotherIntermediateTensors({}) + assert intermediate_tensors != another_intermediate_tensors + + empty_intermediate_tensors_1 = IntermediateTensors({}) + empty_intermediate_tensors_2 = IntermediateTensors({}) + assert empty_intermediate_tensors_1 == empty_intermediate_tensors_2 + + different_key_intermediate_tensors_1 = IntermediateTensors( + {"1": torch.zeros([2, 4], dtype=torch.int32)}) + difference_key_intermediate_tensors_2 = IntermediateTensors( + {"2": torch.zeros([2, 4], dtype=torch.int32)}) + assert (different_key_intermediate_tensors_1 + != difference_key_intermediate_tensors_2) + + same_key_different_value_intermediate_tensors_1 = IntermediateTensors( + {"1": torch.zeros([2, 4], dtype=torch.int32)}) + same_key_different_value_intermediate_tensors_2 = IntermediateTensors( + {"1": torch.zeros([2, 5], dtype=torch.int32)}) + assert (same_key_different_value_intermediate_tensors_1 + != same_key_different_value_intermediate_tensors_2) + + same_key_same_value_intermediate_tensors_1 = IntermediateTensors( + {"1": torch.zeros([2, 4], 
dtype=torch.int32)}) + same_key_same_value_intermediate_tensors_2 = IntermediateTensors( + {"1": torch.zeros([2, 4], dtype=torch.int32)}) + assert (same_key_same_value_intermediate_tensors_1 == + same_key_same_value_intermediate_tensors_2) diff --git a/tests/tool_use/test_qwen3coder_tool_parser.py b/tests/tool_use/test_qwen3coder_tool_parser.py index 40c3158e9e683..ccb2acf512caf 100644 --- a/tests/tool_use/test_qwen3coder_tool_parser.py +++ b/tests/tool_use/test_qwen3coder_tool_parser.py @@ -16,7 +16,7 @@ from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import ( from vllm.transformers_utils.detokenizer import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer -MODEL = "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8" +MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8" @pytest.fixture(scope="module") @@ -397,7 +397,9 @@ hello world "no_tools", "single_tool", "single_tool_with_content", + "single_tool_multiline_param", "parallel_tools", + "tool_with_typed_params", # Added this test case ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ @@ -422,7 +424,7 @@ fahrenheit "state": "TX", "unit": "fahrenheit" }))) - ], ""), + ], None), ('''Sure! Let me check the weather for you.<tool_call> <function=get_current_weather> <parameter=city> @@ -445,6 +447,30 @@ fahrenheit }))) ], "Sure! 
Let me check the weather for you."), ('''<tool_call> +<function=calculate_area> +<parameter=shape> +rectangle +</parameter> +<parameter=dimensions> +{"width": 10, + "height": 20} +</parameter> +<parameter=precision> +2 +</parameter> +</function> +</tool_call>''', [ + ToolCall(function=FunctionCall(name="calculate_area", + arguments=json.dumps({ + "shape": "rectangle", + "dimensions": { + "width": 10, + "height": 20 + }, + "precision": 2 + }))) + ], None), + ('''<tool_call> <function=get_current_weather> <parameter=city> Dallas @@ -484,13 +510,36 @@ celsius "state": "FL", "unit": "celsius" }))) - ], ""), + ], None), + # Added tool_with_typed_params test case + ('''Let me calculate that area for you.<tool_call> +<function=calculate_area> +<parameter=shape> +circle +</parameter> +<parameter=dimensions> +{"radius": 15.5} +</parameter> +<parameter=precision> +3 +</parameter> +</function> +</tool_call>''', [ + ToolCall(function=FunctionCall(name="calculate_area", + arguments=json.dumps({ + "shape": "circle", + "dimensions": { + "radius": 15.5 + }, + "precision": 3 + }))) + ], "Let me calculate that area for you."), ], ) def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, sample_tools, model_output, expected_tool_calls, expected_content): - """Test incremental streaming behavior""" + """Test incremental streaming behavior including typed parameters""" request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) @@ -539,7 +588,7 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, "arguments"] += tool_call.function.arguments # Verify final content - assert other_content == expected_content + assert other_content == (expected_content or "") # Handle None case # Verify we got all expected tool calls assert len(tool_states) == len(expected_tool_calls) @@ -559,6 +608,125 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, assert actual_args == expected_args +def 
test_extract_tool_calls_missing_closing_parameter_tag( + qwen3_tool_parser, sample_tools): + """Test handling of missing closing </parameter> tag""" + # Using get_current_weather from sample_tools but with malformed XML + model_output = '''Let me check the weather for you: +<tool_call> +<function=get_current_weather> +<parameter=city> +Dallas +<parameter=state> +TX +</parameter> +<parameter=unit> +fahrenheit +</parameter> +</function> +</tool_call>''' + + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( + model_output, request=request) + + # The parser should handle the malformed XML gracefully + assert extracted_tool_calls.tools_called + assert len(extracted_tool_calls.tool_calls) == 1 + + # Verify the function name is correct + assert extracted_tool_calls.tool_calls[ + 0].function.name == "get_current_weather" + + # Verify the arguments are parsed despite the missing closing tag + args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) + assert "city" in args + assert args["city"] == "Dallas" + assert args["state"] == "TX" + assert args["unit"] == "fahrenheit" + + # Check that content before the tool call is preserved + assert "Let me check the weather for you:" in extracted_tool_calls.content + + +def test_extract_tool_calls_streaming_missing_closing_tag( + qwen3_tool_parser, qwen3_tokenizer, sample_tools): + """Test streaming with missing closing </parameter> tag""" + # Using get_current_weather from sample_tools but with malformed XML + model_output = '''Let me check the weather for you: +<tool_call> +<function=get_current_weather> +<parameter=city> +Dallas +<parameter=state> +TX +</parameter> +<parameter=unit> +fahrenheit +</parameter> +</function> +</tool_call>''' + + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + + other_content = '' + tool_states = {} + + for delta_message in stream_delta_message_generator( + 
qwen3_tool_parser, qwen3_tokenizer, model_output, request): + + if delta_message.content: + other_content += delta_message.content + + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + idx = tool_call.index + + if idx not in tool_states: + tool_states[idx] = { + "id": None, + "name": None, + "arguments": "", + "type": None + } + + if tool_call.id: + tool_states[idx]["id"] = tool_call.id + + if tool_call.type: + assert tool_call.type == "function" + tool_states[idx]["type"] = tool_call.type + + if tool_call.function: + if tool_call.function.name: + tool_states[idx]["name"] = tool_call.function.name + + if tool_call.function.arguments is not None: + tool_states[idx][ + "arguments"] += tool_call.function.arguments + + # Verify content was streamed + assert "Let me check the weather for you:" in other_content + + # Verify we got the tool call + assert len(tool_states) == 1 + state = tool_states[0] + assert state["id"] is not None + assert state["type"] == "function" + assert state["name"] == "get_current_weather" + + # Verify arguments were parsed correctly despite missing closing tag + assert state["arguments"] is not None + args = json.loads(state["arguments"]) + assert args["city"] == "Dallas" + assert args["state"] == "TX" + assert args["unit"] == "fahrenheit" + + def test_extract_tool_calls_streaming_incremental(qwen3_tool_parser, qwen3_tokenizer, sample_tools): diff --git a/tests/tool_use/test_seed_oss_tool_parser.py b/tests/tool_use/test_seed_oss_tool_parser.py new file mode 100644 index 0000000000000..c276a598aa68c --- /dev/null +++ b/tests/tool_use/test_seed_oss_tool_parser.py @@ -0,0 +1,454 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +import json +from collections.abc import Generator +from typing import Optional + +import pytest + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaMessage, 
FunctionCall, + ToolCall) +from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser +from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer + +# Use a common model that is likely to be available +MODEL = "ByteDance-Seed/Seed-OSS-36B-Instruct" + + +@pytest.fixture(scope="module") +def seed_oss_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True) + + +@pytest.fixture +def seed_oss_tool_parser(seed_oss_tokenizer): + return SeedOssToolParser(seed_oss_tokenizer) + + +@pytest.fixture +def sample_tools(): + return [ + ChatCompletionToolsParam( + type="function", + function={ + "name": "get_weather", + "description": "Get current temperature for a given location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": + "City and country e.g. Bogotá, Colombia" + }, + "unit": { + "type": "string", + "description": "this is the unit of temperature" + } + }, + "required": ["location"], + "additionalProperties": False + }, + "returns": { + "type": "object", + "properties": { + "temperature": { + "type": "number", + "description": "temperature in celsius" + } + }, + "required": ["temperature"], + "additionalProperties": False + }, + "strict": True + }), + ] + + +def assert_tool_calls(actual_tool_calls: list[ToolCall], + expected_tool_calls: list[ToolCall]): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip(actual_tool_calls, + expected_tool_calls): + # Seed-OSS tool call will not generate id + assert actual_tool_call.type == "function" + assert actual_tool_call.function == expected_tool_call.function + + assert actual_tool_call.function.name == expected_tool_call.function.name + assert actual_tool_call.function.arguments == expected_tool_call.function.arguments + + +def test_extract_tool_calls_no_tools(seed_oss_tool_parser): + 
model_output = "This is a test response without any tool calls" + extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output + + +@pytest.mark.parametrize( + ids=[ + "tool_call_0_thinking_budget", + "tool_call_512_thinkg_budget", + "tool_call_unlimited_thinking_budget", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ("""<seed:tool_call>\n<function=get_weather>\n""" + """<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "location": "Barcelona, Spain", + }, ), + ), + type='function') + ], None), + ( + """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ + """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ + """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ + """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ + """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ + """country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use.""" + """</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """ + """is fine. \n\nThere\'s no need to ask for more information because the location is clear. 
So I should call """ + """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ + """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ + """with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """ + """use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""" + """<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>""" + """\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "location": "Barcelona, Spain", + }, ), + ), + type='function') + ], + """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ + """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ + """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ + """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ + """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ + """country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use.""" + """</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """ + """is fine. \n\nThere\'s no need to ask for more information because the location is clear. 
So I should call """ + """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ + """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ + """with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """ + """use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""", + ), + ( + """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ + """First, I need to remember the function I can use: get_weather. The function requires a """ + """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ + """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ + """let me check the function docstring again. Oh, the function says unit is optional, and """ + """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ + """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ + """The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """ + """Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """ + """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ + """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ + """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ + """call should be as above. 
Then wait for the result to come back and tell the user the """ + """temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>""" + """Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + "unit": "celsius", + }, ), + ), + type='function') + ], + """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ + """First, I need to remember the function I can use: get_weather. The function requires a """ + """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ + """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ + """let me check the function docstring again. Oh, the function says unit is optional, and """ + """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ + """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ + """The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """ + """Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """ + """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ + """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ + """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ + """call should be as above. 
Then wait for the result to come back and tell the user the """ + """temperature in Celsius.</seed:think>""", + ), + ], +) +def test_extract_tool_calls(seed_oss_tool_parser, sample_tools, model_output, + expected_tool_calls, expected_content): + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls( + model_output, request=request) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +def test_streaming_tool_calls_no_tools(seed_oss_tool_parser): + model_output = "This is a test response without any tool calls" + + result = seed_oss_tool_parser.extract_tool_calls_streaming( + previous_text="his is a test response", + current_text=model_output, + delta_text=" without any tool calls.", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + # Should return the delta text as content + assert result is not None + assert hasattr(result, 'content') + assert result.content == " without any tool calls." 
+ + +def stream_delta_message_generator( + seed_oss_tool_parser: SeedOssToolParser, + seed_oss_tokenizer: AnyTokenizer, + model_output: str, + request: Optional[ChatCompletionRequest] = None +) -> Generator[DeltaMessage, None, None]: + all_token_ids = seed_oss_tokenizer.encode(model_output, + add_special_tokens=False) + + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i, delta_token in enumerate(all_token_ids): + delta_token_ids = [delta_token] + previous_token_ids = all_token_ids[:i] + current_token_ids = all_token_ids[:i + 1] + + (new_tokens, delta_text, new_prefix_offset, + new_read_offset) = detokenize_incrementally( + tokenizer=seed_oss_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + + current_text = previous_text + delta_text + + delta_message = seed_oss_tool_parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request=request, + ) + if delta_message: + yield delta_message + + previous_text = current_text + previous_tokens = (previous_tokens + + new_tokens if previous_tokens else new_tokens) + prefix_offset = new_prefix_offset + read_offset = new_read_offset + + +@pytest.mark.parametrize( + ids=[ + "tool_call_0_thinking_budget", + "tool_call_512_thinkg_budget", + "tool_call_unlimited_thinking_budget", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n""" + """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""" + """<seed:tool_call>\n<function=get_weather>\n""" + """<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + 
name="get_weather", + arguments=json.dumps({ + "location": "Barcelona, Spain", + }, ), + ), + type='function') + ], + """<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n""" + """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""" + ), + ( + """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ + """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ + """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ + """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ + """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ + """country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use.""" + """</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """ + """is fine. \n\nThere\'s no need to ask for more information because the location is clear. 
So I should call """ + """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ + """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ + """with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """ + """use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""" + """<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>""" + """\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "location": "Barcelona, Spain", + }, ), + ), + type='function') + ], + """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ + """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ + """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ + """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ + """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ + """country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use.""" + """</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """ + """is fine. \n\nThere\'s no need to ask for more information because the location is clear. 
So I should call """ + """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ + """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ + """with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """ + """use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""", + ), + ( + """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ + """First, I need to remember the function I can use: get_weather. The function requires a """ + """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ + """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ + """let me check the function docstring again. Oh, the function says unit is optional, and """ + """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ + """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ + """The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """ + """Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """ + """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ + """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ + """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ + """call should be as above. 
Then wait for the result to come back and tell the user the """ + """temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>""" + """Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + "unit": "celsius", + }, ), + ), + type='function') + ], + """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ + """First, I need to remember the function I can use: get_weather. The function requires a """ + """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ + """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ + """let me check the function docstring again. Oh, the function says unit is optional, and """ + """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ + """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ + """The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """ + """Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """ + """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ + """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ + """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ + """call should be as above. 
Then wait for the result to come back and tell the user the """ + """temperature in Celsius.</seed:think>""", + ), + ], +) +def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer, + sample_tools, model_output, expected_tool_calls, + expected_content): + """Test incremental streaming behavior""" + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + + other_content = '' + tool_states = {} # Track state per tool index + + for delta_message in stream_delta_message_generator( + seed_oss_tool_parser, seed_oss_tokenizer, model_output, request): + # role should never be streamed from tool parser + assert not delta_message.role + + if delta_message.content: + other_content += delta_message.content + + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + idx = tool_call.index + + # Initialize state for new tool + if idx not in tool_states: + tool_states[idx] = { + "id": None, + "name": None, + "arguments": "", + "type": None + } + + # First chunk should have id, name, and type + if tool_call.id: + tool_states[idx]["id"] = tool_call.id + + if tool_call.type: + assert tool_call.type == "function" + tool_states[idx]["type"] = tool_call.type + + if tool_call.function: + if tool_call.function.name: + # Should only be set once + assert tool_states[idx]["name"] is None + tool_states[idx]["name"] = tool_call.function.name + + if tool_call.function.arguments is not None: + # Accumulate arguments incrementally + tool_states[idx][ + "arguments"] += tool_call.function.arguments + + # Verify final content + assert other_content == expected_content + + # Verify we got all expected tool calls + assert len(tool_states) == len(expected_tool_calls) + + # Verify each tool call + for idx, expected_tool in enumerate(expected_tool_calls): + state = tool_states[idx] + assert state["id"] is not None + assert state["type"] == "function" + assert state["name"] == expected_tool.function.name + + # Parse accumulated arguments + 
arguments_str = state["arguments"] + assert arguments_str is not None + actual_args = json.loads(arguments_str) + expected_args = json.loads(expected_tool.function.arguments) + assert actual_args == expected_args diff --git a/tests/utils.py b/tests/utils.py index 18fcde949160e..9d2073f3c1036 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,7 @@ import asyncio import copy import functools import importlib +import json import os import signal import subprocess @@ -13,6 +14,7 @@ import tempfile import time import warnings from contextlib import contextmanager, suppress +from multiprocessing import Process from pathlib import Path from typing import Any, Callable, Literal, Optional, Union @@ -76,6 +78,23 @@ VLLM_PATH = Path(__file__).parent.parent class RemoteOpenAIServer: DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key + def _start_server(self, model: str, vllm_serve_args: list[str], + env_dict: Optional[dict[str, str]]) -> None: + """Subclasses override this method to customize server process launch + """ + env = os.environ.copy() + # the current process might initialize cuda, + # to be safe, we should use spawn method + env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + if env_dict is not None: + env.update(env_dict) + self.proc: subprocess.Popen = subprocess.Popen( + ["vllm", "serve", model, *vllm_serve_args], + env=env, + stdout=sys.stdout, + stderr=sys.stderr, + ) + def __init__(self, model: str, vllm_serve_args: list[str], @@ -83,7 +102,8 @@ class RemoteOpenAIServer: env_dict: Optional[dict[str, str]] = None, seed: Optional[int] = 0, auto_port: bool = True, - max_wait_seconds: Optional[float] = None) -> None: + max_wait_seconds: Optional[float] = None, + override_hf_configs: Optional[dict[str, Any]] = None) -> None: if auto_port: if "-p" in vllm_serve_args or "--port" in vllm_serve_args: raise ValueError("You have manually specified the port " @@ -102,6 +122,12 @@ class RemoteOpenAIServer: vllm_serve_args = vllm_serve_args + 
["--seed", str(seed)] + if override_hf_configs is not None: + vllm_serve_args = vllm_serve_args + [ + "--hf-overrides", + json.dumps(override_hf_configs) + ] + parser = FlexibleArgumentParser( description="vLLM's remote OpenAI server.") subparsers = parser.add_subparsers(required=False, dest="subparser") @@ -128,18 +154,7 @@ class RemoteOpenAIServer: model_loader = get_model_loader(load_config) model_loader.download_model(model_config) - env = os.environ.copy() - # the current process might initialize cuda, - # to be safe, we should use spawn method - env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' - if env_dict is not None: - env.update(env_dict) - self.proc = subprocess.Popen( - ["vllm", "serve", model, *vllm_serve_args], - env=env, - stdout=sys.stdout, - stderr=sys.stderr, - ) + self._start_server(model, vllm_serve_args, env_dict) max_wait_seconds = max_wait_seconds or 240 self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds) @@ -155,6 +170,10 @@ class RemoteOpenAIServer: # force kill if needed self.proc.kill() + def _poll(self) -> Optional[int]: + """Subclasses override this method to customize process polling""" + return self.proc.poll() + def _wait_for_server(self, *, url: str, timeout: float): # run health check start = time.time() @@ -169,7 +188,7 @@ class RemoteOpenAIServer: # which means the server is not ready yet. # the stack trace is not useful, so we suppress it # by using `raise from None`. 
- result = self.proc.poll() + result = self._poll() if result is not None and result != 0: raise RuntimeError("Server exited unexpectedly.") from None @@ -205,6 +224,48 @@ class RemoteOpenAIServer: **kwargs) +class RemoteOpenAIServerCustom(RemoteOpenAIServer): + """Launch test server with custom child process""" + + def _start_server(self, model: str, vllm_serve_args: list[str], + env_dict: Optional[dict[str, str]]) -> None: + self.proc: Process = Process( + target=self.child_process_fxn, + args=(env_dict, model, + vllm_serve_args)) # type: ignore[assignment] + self.proc.start() + + def __init__(self, + model: str, + vllm_serve_args: list[str], + child_process_fxn: Callable[ + [Optional[dict[str, str]], str, list[str]], None], + *, + env_dict: Optional[dict[str, str]] = None, + seed: Optional[int] = 0, + auto_port: bool = True, + max_wait_seconds: Optional[float] = None) -> None: + """Store custom child process function then invoke superclass + constructor which will indirectly launch it.""" + self.child_process_fxn = child_process_fxn + super().__init__(model=model, + vllm_serve_args=vllm_serve_args, + env_dict=env_dict, + seed=seed, + auto_port=auto_port, + max_wait_seconds=max_wait_seconds) + + def _poll(self) -> Optional[int]: + return self.proc.exitcode + + def __exit__(self, exc_type, exc_value, traceback): + self.proc.terminate() + self.proc.join(8) + if self.proc.is_alive(): + # force kill if needed + self.proc.kill() + + def _test_completion( client: openai.OpenAI, model: str, @@ -635,9 +696,12 @@ def multi_process_parallel( os.environ["RAY_RUNTIME_ENV_IGNORE_GITIGNORE"] = "1" ray.init( runtime_env={ - "working_dir": VLLM_PATH, - "excludes": - ["build", ".git", "cmake-build-*", "shellcheck", "dist"] + "working_dir": + VLLM_PATH, + "excludes": [ + "build", ".git", "cmake-build-*", "shellcheck", "dist", + "ep_kernels_workspace" + ] }) distributed_init_port = get_open_port() diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index 
084d82dee11b3..04195ea0cf92e 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -5,13 +5,17 @@ import asyncio import hashlib import json +import os import pickle import socket +import tempfile from collections.abc import AsyncIterator +from pathlib import Path from unittest.mock import patch import pytest import torch +import yaml import zmq from transformers import AutoTokenizer from vllm_test_utils.monitor import monitor @@ -991,3 +995,40 @@ def test_current_stream_multithread(): child_thread.join(timeout=5) if child_thread.is_alive(): pytest.fail("Child thread failed to exit properly") + + +def test_load_config_file(tmp_path): + # Define the configuration data + config_data = { + "enable-logging": True, + "list-arg": ["item1", "item2"], + "port": 12323, + "tensor-parallel-size": 4 + } + + # Write the configuration data to a temporary YAML file + config_file_path = tmp_path / "config.yaml" + with open(config_file_path, "w") as config_file: + yaml.dump(config_data, config_file) + + # Initialize the parser + parser = FlexibleArgumentParser() + + # Call the function with the temporary file path + processed_args = parser.load_config_file(str(config_file_path)) + + # Expected output + expected_args = [ + "--enable-logging", + "--list-arg", + "item1", + "item2", + "--port", + "12323", + "--tensor-parallel-size", + "4", + ] + + # Assert that the processed arguments match the expected output + assert processed_args == expected_args + os.remove(str(config_file_path)) diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index ac08b9052cd80..e4c07aae0ebed 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -10,14 +10,15 @@ from tests.v1.attention.utils import (BatchSpec, _Backend, create_standard_kv_cache_spec, create_vllm_config, get_attention_backend) -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.utils import 
STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, set_kv_cache_layout) from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1, - _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN + _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN, + "FLEX_ATTENTION_SLOW" ] # Remove flashinfer from the list if it's not available @@ -97,7 +98,7 @@ def create_and_prepopulate_kv_cache( common_attn_metadata: CommonAttentionMetadata, randomize_blocks: bool = True) -> torch.Tensor: """Create and prepopulate a KV cache with context data. - + Args: k_contexts: List of key context tensors for each sequence v_contexts: List of value context tensors for each sequence @@ -109,9 +110,9 @@ def create_and_prepopulate_kv_cache( device: Device to create the cache on num_blocks: Total number of blocks in the cache block_table: Block table tensor to populate - randomize_blocks: Whether to randomly permute blocks + randomize_blocks: Whether to randomly permute blocks or use sequential order - + Returns: Tuple of (kv_cache, updated_block_table) """ @@ -150,15 +151,15 @@ def create_and_prepopulate_kv_cache( # Permute the context blocks (excluding block 0 which is null) if randomize_blocks: - perm = torch.randperm( - blocks_end - 1) + 1 # Random permutation starting from block 1 + # Random permutation starting from block 1 + perm = torch.randperm(blocks_end - 1) + 1 else: - perm = torch.arange( - 1, blocks_end) # Sequential order starting from block 1 + # Sequential order starting from block 1 + perm = torch.arange(1, blocks_end) inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) - inv_perm[1:] = torch.argsort( - perm) + 1 # Add 1 to account for starting from block 1 + # Add 1 to account for starting from block 1 + inv_perm[1:] = torch.argsort(perm) + 1 kv_cache[:, 1:blocks_end, ...] 
= kv_cache[:, perm, ...] # Construct the right block table @@ -206,10 +207,18 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, kv_cache: torch.Tensor) -> torch.Tensor: """Run attention computation using the specified backend's AttentionImpl.""" - builder_cls, impl_cls = get_attention_backend(backend) + # Handle special case for FLEX_ATTENTION_SLOW + actual_backend = backend + + use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0") + if backend == "FLEX_ATTENTION_SLOW": + actual_backend = _Backend.FLEX_ATTENTION + use_direct_block_mask = False + + builder_cls, impl_cls = get_attention_backend(actual_backend) # Mock flashinfer's get_per_layer_parameters if needed - if backend == _Backend.FLASHINFER_VLLM_V1: + if actual_backend == _Backend.FLASHINFER_VLLM_V1: import unittest.mock from vllm.v1.attention.backends.utils import PerLayerParameters @@ -239,6 +248,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, else: # Build metadata builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) + if actual_backend == _Backend.FLEX_ATTENTION: + builder.direct_build = use_direct_block_mask attn_metadata = builder.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, @@ -281,7 +292,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, @pytest.mark.parametrize("batch_spec_name", [ "small_decode", "small_prefill", "mixed_small", "medium_decode", - "medium_prefill", "mixed_medium" + "medium_prefill", "mixed_medium", "large_decode", "large_prefill", + "single_decode", "single_prefill" ]) @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) def test_backend_correctness(batch_spec_name: str, model: str): @@ -302,7 +314,8 @@ def test_backend_correctness(batch_spec_name: str, model: str): """ batch_spec = BATCH_SPECS[batch_spec_name] vllm_config = create_vllm_config(model_name=model, - max_model_len=max(batch_spec.seq_lens)) + 
max_model_len=max(batch_spec.seq_lens), + num_gpu_blocks=8192) device = torch.device("cuda:0") kv_cache_spec = create_standard_kv_cache_spec(vllm_config) @@ -451,11 +464,6 @@ def test_backend_correctness(batch_spec_name: str, model: str): rtol = 1e-2 atol = 5e-3 - if backend_name == _Backend.FLEX_ATTENTION: - atol = 5e-1 # TODO: figure out why flex_attention has such large - # numerical differences for medium_decode, medium_prefill, - # mixed_medium - max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item() max_rel_diff = torch.max( torch.abs(backend_output - sdpa_output) / @@ -465,12 +473,6 @@ def test_backend_correctness(batch_spec_name: str, model: str): rtol=rtol, atol=atol) - if not all_close: - print(f"[{backend_name}] output differs from SDPA baseline. " - f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})") - print(f"[{backend_name}] output: {backend_output}") - print(f"[{backend_name}] SDPA baseline: {sdpa_output}") - assert all_close, ( f"[{backend_name}] output differs from SDPA baseline. 
" - f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})") + f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})") \ No newline at end of file diff --git a/tests/v1/attention/test_attention_backends_selection.py b/tests/v1/attention/test_attention_backends_selection.py new file mode 100644 index 0000000000000..59e5628149468 --- /dev/null +++ b/tests/v1/attention/test_attention_backends_selection.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for mamba attention backend selectors.""" + +from types import SimpleNamespace + +import pytest + +from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.short_conv import ShortConv +from vllm.model_executor.models.minimax_text_01 import ( + MiniMaxText01LinearAttention) +from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend +from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend +from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend +from vllm.v1.attention.backends.short_conv_attn import ( + ShortConvAttentionBackend) + + +@pytest.mark.parametrize( + "layer_class, init_kwargs, expected_backend, expected_mamba_type", [ + ( + MambaMixer, + dict( + hidden_size=128, + ssm_state_size=16, + conv_kernel_size=4, + intermediate_size=256, + time_step_rank=8, + use_conv_bias=True, + use_bias=False, + use_rms_norm=True, + ), + Mamba1AttentionBackend, + "mamba1", + ), + ( + MambaMixer2, + dict( + hidden_size=128, + ssm_state_size=16, + conv_kernel_size=4, + intermediate_size=256, + use_conv_bias=True, + use_bias=False, + n_groups=1, + num_heads=8, + head_dim=32, + ), + Mamba2AttentionBackend, + "mamba2", + ), + ( + MiniMaxText01LinearAttention, + dict( + hidden_size=128, + hidden_inner_size=256, + num_heads=8, + head_dim=32, + max_position=2048, + 
block_size=64, + num_hidden_layer=12, + layer_idx=0, + linear_layer_idx=0, + ), + LinearAttentionBackend, + "linear_attention", + ), + ( + ShortConv, + dict( + config=SimpleNamespace(conv_L_cache=32, conv_bias=True), + dim=128, + layer_idx=0, + ), + ShortConvAttentionBackend, + "short_conv", + ), + ]) +def test_mamba_layers_get_attn_backend(dist_init, layer_class, init_kwargs, + expected_backend, expected_mamba_type): + """Test that Mamba-like layers return the correct attention backend.""" + layer = layer_class(**init_kwargs) + + backend_class = layer.get_attn_backend() + assert backend_class is expected_backend + assert layer.mamba_type == expected_mamba_type + + +@pytest.mark.parametrize("layer_class,expected_backend,expected_mamba_type", [ + (MambaMixer, Mamba1AttentionBackend, "mamba1"), + (MambaMixer2, Mamba2AttentionBackend, "mamba2"), + (MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"), + (ShortConv, ShortConvAttentionBackend, "short_conv"), +]) +def test_mamba_layers_have_unified_interface(layer_class, expected_backend, + expected_mamba_type): + """Test that all Mamba layers have the unified get_attn_backend + interface.""" + assert hasattr(layer_class, 'get_attn_backend'), ( + f"{layer_class.__name__} should have get_attn_backend method") + assert hasattr(layer_class, 'mamba_type'), ( + f"{layer_class.__name__} should have mamba_type property") diff --git a/tests/v1/attention/test_mamba_selectors.py b/tests/v1/attention/test_mamba_selectors.py deleted file mode 100644 index 4245b50c71310..0000000000000 --- a/tests/v1/attention/test_mamba_selectors.py +++ /dev/null @@ -1,25 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for mamba attention backend selectors.""" - -import pytest - -from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend -from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend - - 
-@pytest.mark.parametrize(argnames=["mamba_type", "expected_backend"], - argvalues=[("mamba2", Mamba2AttentionBackend)]) -def test_get_mamba_attn_backend_mamba2(mamba_type, expected_backend): - backend_class = get_mamba_attn_backend(mamba_type) - - assert backend_class is expected_backend - - -def test_get_mamba_attn_backend_unsupported(): - unsupported_types = ["mamba", ""] - - for mamba_type in unsupported_types: - err_message = f"Mamba Attention type {mamba_type} is not supported yet." - with pytest.raises(NotImplementedError, match=err_message): - get_mamba_attn_backend(mamba_type) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py new file mode 100644 index 0000000000000..24070358799ef --- /dev/null +++ b/tests/v1/attention/test_mla_backends.py @@ -0,0 +1,522 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for v1 MLA backends without GPUModelRunner dependency.""" + +import pytest +import torch + +from tests.v1.attention.utils import (BatchSpec, _Backend, + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, + get_attention_backend) +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import FullAttentionSpec + +BACKENDS_TO_TEST = [ + _Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, + _Backend.TRITON_MLA_VLLM_V1 +] + +# Remove CUTLASS_MLA from the list if not using sm100 +if not torch.cuda.is_available() or torch.cuda.get_device_properties( + 0).major < 10: + BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA) + +torch.manual_seed(42) + + +def _convert_dtype_to_torch(dtype): + """Convert ModelDType to torch.dtype.""" + if isinstance(dtype, str): + if dtype == "auto": + return torch.float16 # Default dtype for testing + elif dtype in STR_DTYPE_TO_TORCH_DTYPE: + return STR_DTYPE_TO_TORCH_DTYPE[dtype] + else: + raise 
ValueError(f"Unknown dtype: {dtype}") + elif isinstance(dtype, torch.dtype): + return dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + +# Define common batch configurations +BATCH_SPECS = { + "small_decode": + BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), + "small_prefill": + BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), + "mixed_small": + BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), + "medium_decode": + BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], + query_lens=[1, 1, 1, 1, 1, 1, 1, 1]), + "medium_prefill": + BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]), + "mixed_medium": + BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048], + query_lens=[1, 1, 1, 7, 7, 7]), + "large_decode": + BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), + "large_prefill": + BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), + "single_decode": + BatchSpec(seq_lens=[1024], query_lens=[1]), + "single_prefill": + BatchSpec(seq_lens=[1024], query_lens=[64]), +} + + +def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec, + device: torch.device, + num_blocks: int = 100) -> torch.Tensor: + """Create a dummy KV cache tensor for testing.""" + kv_cache = torch.randn( + num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.head_size, # latent dimension + dtype=_convert_dtype_to_torch(kv_cache_spec.dtype), + device=device, + ) + return kv_cache + + +def create_and_prepopulate_kv_cache( + kv_c_contexts: list[torch.Tensor], + k_pe_contexts: list[torch.Tensor], + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + num_blocks: int, + common_attn_metadata: CommonAttentionMetadata, + randomize_blocks: bool = True) -> torch.Tensor: + """Create and prepopulate an MLA KV cache with context data. 
+ + Args: + kv_c_contexts: List of latent KV context tensors for each sequence + k_pe_contexts: List of key positional embedding context tensors + for each sequence + block_size: Size of each block + num_kv_heads: Number of KV heads (should be 1 for MLA) + head_size: Size of each head (latent dimension) + dtype: Data type for the cache + device: Device to create the cache on + num_blocks: Total number of blocks in the cache + common_attn_metadata: Common attention metadata + randomize_blocks: Whether to randomly permute blocks + or use sequential order + + Returns: + MLA KV cache tensor + """ + batch_size = len(kv_c_contexts) + seq_lens = common_attn_metadata.seq_lens_cpu + query_lens = common_attn_metadata.query_start_loc_cpu[ + 1:] - common_attn_metadata.query_start_loc_cpu[:-1] + context_lens = common_attn_metadata.num_computed_tokens_cpu + block_table = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + # Create MLA KV cache: (num_blocks, block_size, head_size) + kv_cache = torch.empty(num_blocks, + block_size, + head_size, + dtype=dtype, + device=device) + kv_cache_flat = kv_cache.view(-1, head_size) + + # Populate the cache with the context tokens + # Start from block_id=1 since block_id=0 is considered the null block + start_block_idx = 1 + for i in range(batch_size): + kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i] + kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1) + start = start_block_idx * block_size + end = start + kv_context.shape[0] + kv_cache_flat[start:end, ...] 
= kv_context + + # Stay block aligned and allocate enough blocks for the new tokens + start_block_idx += cdiv(int(seq_lens[i]), block_size) + + blocks_end = start_block_idx + + # Permute the context blocks (excluding block 0 which is null) + if randomize_blocks: + perm = torch.randperm( + blocks_end - 1) + 1 # Random permutation starting from block 1 + else: + perm = torch.arange( + 1, blocks_end) # Sequential order starting from block 1 + + inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) + inv_perm[1:] = torch.argsort( + perm) + 1 # Add 1 to account for starting from block 1 + kv_cache[1:blocks_end, ...] = kv_cache[perm, ...] + + # Construct the right block table + # Start from block_id=1 since block_id=0 is considered the null block + start_block_idx = 1 + for i in range(batch_size): + num_blocks_for_seq = cdiv(int(seq_lens[i]), block_size) + start = start_block_idx + end = start + num_blocks_for_seq + block_table[i, :num_blocks_for_seq] = inv_perm[start:end] + start_block_idx += num_blocks_for_seq + + # Create a realistic slot mapping that corresponds to the block table + for i in range(batch_size): + token_offsets = torch.arange(int(query_lens[i])) + int(context_lens[i]) + block_indices = token_offsets // block_size + token_inter_block_offsets = token_offsets % block_size + start = common_attn_metadata.query_start_loc_cpu[i] + end = common_attn_metadata.query_start_loc_cpu[i + 1] + slot_mapping[start:end] = block_table[ + i, + block_indices] * block_size + token_inter_block_offsets.to(device) + + return kv_cache + + +class MockAttentionLayer: + """A mock attention layer for testing.""" + + def __init__(self, device: torch.device): + self._q_scale = torch.tensor(1.0, device=device) + self._k_scale = torch.tensor(1.0, device=device) + self._v_scale = torch.tensor(1.0, device=device) + + +def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, + layer_names: list[str], vllm_config, + device: torch.device, + 
common_attn_metadata: CommonAttentionMetadata, + query: torch.Tensor, kv_c: torch.Tensor, + k_pe: torch.Tensor, kv_cache: torch.Tensor, + kv_lora_rank: int, qk_nope_head_dim: int, + qk_rope_head_dim: int, v_head_dim: int, + mock_kv_b_proj) -> torch.Tensor: + """Run attention computation using the specified backend's AttentionImpl.""" + + builder_cls, impl_cls = get_attention_backend(backend) + + # Build metadata + builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + + # Instantiate MLA implementation + num_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config) + num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + scale = 1.0 / (head_size**0.5) + impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + logits_soft_cap=None, + attn_type="decoder", + kv_sharing_target_layer_name=None, + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, + v_head_dim=v_head_dim, + kv_b_proj=mock_kv_b_proj, + ) + + # Process weights to create W_UK_T and W_UV attributes needed by MLA + act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) + impl.process_weights_after_loading(act_dtype) + + # Create mock layer and output buffer + mock_layer = MockAttentionLayer(device) + num_tokens = query.shape[0] + output = torch.empty(num_tokens, + num_heads * v_head_dim, + dtype=query.dtype, + device=query.device) + + # Run forward pass + # NOTE: The query, key, and value are already shaped correctly + # in the calling test function. 
+ output = impl.forward(mock_layer, + query, + kv_c, + k_pe, + kv_cache, + attn_metadata, + output=output) + + return output + + +@pytest.mark.parametrize("batch_spec_name", [ + "small_decode", "small_prefill", "mixed_small", "medium_decode", + "medium_prefill", "mixed_medium", "large_decode", "large_prefill", + "single_decode", "single_prefill" +]) +@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"]) +def test_backend_correctness(dist_init, batch_spec_name: str, model: str): + """ + Test that all backends produce similar outputs to a reference implementation + using torch.nn.functional.scaled_dot_product_attention. + + This test works by: + 1. Generating a batch of sequences with specified context and query lengths. + 2. Computing a ground-truth attention output using torch.sdpa on + contiguous Q, K, and V tensors. + 3. Simulating vLLM's paged KV cache: It takes the context portion of the + K/V tensors and manually places them into a paged buffer according to + the test's (randomly generated) block table. + 4. Running each vLLM attention backend with the new queries and the + simulated paged KV cache. + 5. Comparing the vLLM backend's output to the ground-truth SDPA output. + """ + batch_spec = BATCH_SPECS[batch_spec_name] + vllm_config = create_vllm_config(model_name=model, + max_model_len=max(batch_spec.seq_lens), + num_gpu_blocks=2048) + device = torch.device("cuda:0") + + kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + + # 1. 
Setup + batch_size = batch_spec.batch_size + seq_lens = batch_spec.seq_lens + query_lens = batch_spec.query_lens + num_q_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config) + num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) + block_size = vllm_config.cache_config.block_size + kv_lora_rank = 512 + qk_rope_head_dim = 64 + qk_nope_head_dim = 128 + v_head_dim = 128 + total_head_size = kv_lora_rank + qk_rope_head_dim + assert kv_lora_rank + qk_rope_head_dim == head_size, \ + f"MLA dimensions don't match: {total_head_size} != {head_size}" + scale = 1.0 / (total_head_size**0.5) + + # 2. Generate data and compute SDPA reference output for MLA + all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], [] + all_sdpa_outputs = [] + kv_c_contexts, k_pe_contexts = [], [] + + # Create shared MLA weight matrices for consistency across all sequences + W_UK = torch.randn(kv_lora_rank, + num_q_heads, + qk_nope_head_dim, + dtype=dtype, + device=device) + W_UV = torch.randn(kv_lora_rank, + num_q_heads, + v_head_dim, + dtype=dtype, + device=device) + kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) + + for i in range(batch_size): + s_len = seq_lens[i] + q_len = query_lens[i] + context_len = s_len - q_len + + # Generate MLA tensors + # Q has both nope and rope components: + # [q_len, num_heads, qk_nope_head_dim + qk_rope_head_dim] + q_c = torch.randn(q_len, + num_q_heads, + qk_nope_head_dim + qk_rope_head_dim, + dtype=dtype, + device=device) + + # KV_C (latent K/V): [s_len, kv_lora_rank] + kv_c_full = torch.randn(s_len, + kv_lora_rank, + dtype=dtype, + device=device) + + # K_PE (rope component): [s_len, 1, qk_rope_head_dim] + k_pe_full = torch.randn(s_len, + 1, + qk_rope_head_dim, + dtype=dtype, + device=device) + + # Determine if this is decode (single token) + # or prefill (multiple tokens) + 
is_decode = q_len == 1 + + # Split q into nope and rope components + q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) + + if is_decode: + # Decode path: MQA-style attention in latent space + # Transform q_nope to latent space: q_nope @ W_UK + # q_nope: [1, num_heads, qk_nope_head_dim] + # W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim] + ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, + W_UK) # [1, num_heads, kv_lora_rank] + + # Build MQA attention inputs + # Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim] + q_mqa = torch.cat([ql_nope, q_pe], dim=-1) + # K: [s_len, kv_lora_rank + qk_rope_head_dim] + # (broadcasted to all heads) + k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1) + k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1) + # V: [s_len, kv_lora_rank] (broadcasted to all heads) + v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1) + + # SDPA expects (N, H, L, D) + q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) + + sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, is_causal=False, scale=scale) + sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze( + 0) # [1, num_heads, kv_lora_rank] + + # Project back to output space: sdpa_out @ W_UV + sdpa_out_i = torch.einsum("qnl,lnv->qnv", sdpa_out_i, W_UV) + sdpa_out_i = sdpa_out_i.flatten(start_dim=-2) + else: + # Prefill path: MHA-style attention with full sequence + # Apply kv_b_proj to the full kv_c tensor + kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, + kv_b_proj_weight) + k_nope_full, v_full = kv_nope_full.split( + [qk_nope_head_dim, v_head_dim], dim=-1) + + # Build attention inputs for full sequence + q_mha = torch.cat([q_nope, q_pe], + dim=-1) # [q_len, num_heads, total_dim] + k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1) + k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1) + + # Create custom 
attention mask: + # - Query tokens can attend to all context tokens + # - Query tokens can only attend to query tokens up to their pos + attn_mask = torch.ones(q_len, + s_len, + dtype=torch.bool, + device=device) + # Apply causal mask only to the query portion (context_len onwards) + causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) + attn_mask[:, context_len:] = causal_mask + + # SDPA expects (N, H, L, D) + q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) + + # Single attention call with custom mask + sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, + k_sdpa_in, + v_sdpa_in, + attn_mask=attn_mask, + scale=scale) + sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(0) + sdpa_out_i = sdpa_out_i.flatten(start_dim=-2) + + all_sdpa_outputs.append(sdpa_out_i) + + # Inputs for vLLM MLA backends are just the new tokens + all_q_vllm.append(q_c) + all_kv_c_vllm.append(kv_c_full[context_len:]) # New kv_c tokens + all_k_pe_vllm.append(k_pe_full[context_len:]) # New k_pe tokens + + # Contextual K/V data used to populate the paged cache (MLA format) + kv_c_contexts.append(kv_c_full[:context_len]) + k_pe_contexts.append(k_pe_full[:context_len]) + + # Concatenate all sequences (no reordering needed) + query_vllm = torch.cat(all_q_vllm, dim=0) + kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) + k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) + sdpa_output = torch.cat(all_sdpa_outputs, dim=0) + + # Create mock kv_b_proj using the same weights as reference implementation + from vllm.model_executor.layers.linear import ColumnParallelLinear + mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank, + output_size=num_q_heads * + (qk_nope_head_dim + v_head_dim), + bias=False).to(device=device, + dtype=dtype) + + # Set the mock weights to match our reference implementation + # Reshape W_UK and W_UV to match the expected kv_b_proj format + # 
[kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim] + kv_b_proj_weight = kv_b_proj_weight.view( + kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim)) + mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T) + + # Create metadata using original batch spec + common_attn_metadata = create_common_attn_metadata( + batch_spec, vllm_config.cache_config.block_size, device) + + # 3. Simulate Paged KV Cache and a realistic slot_mapping + kv_cache = create_and_prepopulate_kv_cache( + kv_c_contexts=kv_c_contexts, + k_pe_contexts=k_pe_contexts, + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + device=device, + num_blocks=vllm_config.cache_config.num_gpu_blocks, + common_attn_metadata=common_attn_metadata, + randomize_blocks=True) + + # 4. Run vLLM backends and compare + for backend_name in BACKENDS_TO_TEST: + backend_output = run_attention_backend( + backend_name, kv_cache_spec, ["placeholder"], vllm_config, device, + common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, + kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, + mock_kv_b_proj) + + # Check shape and dtype consistency + assert backend_output.shape == sdpa_output.shape, ( + f"[{backend_name}] shape {backend_output.shape} != " + f"SDPA shape {sdpa_output.shape}") + assert backend_output.dtype == sdpa_output.dtype, ( + f"[{backend_name}] dtype {backend_output.dtype} != " + f"SDPA dtype {sdpa_output.dtype}") + + assert torch.isfinite(backend_output).all(), ( + f"[{backend_name}] produced non-finite values") + + # Check numerical similarity + rtol = 1e-2 + atol = 5e-1 + + max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item() + max_rel_diff = torch.max( + torch.abs(backend_output - sdpa_output) / + torch.abs(sdpa_output)).item() + all_close = torch.allclose(backend_output, + sdpa_output, + rtol=rtol, + atol=atol) + + assert all_close, ( + f"[{backend_name}] output differs from SDPA baseline. 
" + f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})") diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index a4e38eb32f6a1..6a08cdc56f736 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -58,6 +58,7 @@ def create_common_attn_metadata( dtype=torch.int32, device=device) seq_lens_cpu = seq_lens.cpu() + max_seq_len = int(seq_lens_cpu.max()) # Create computed tokens (context length for each sequence) context_lens = [ @@ -101,6 +102,7 @@ def create_common_attn_metadata( num_reqs=batch_spec.batch_size, num_actual_tokens=num_tokens, max_query_len=max_query_len, + max_seq_len=max_seq_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, causal=True, @@ -133,6 +135,12 @@ def get_attention_backend(backend_name: _Backend): "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", _Backend.XFORMERS_VLLM_V1: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", + _Backend.CUTLASS_MLA: + "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", + _Backend.FLASHMLA_VLLM_V1: + "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", + _Backend.TRITON_MLA_VLLM_V1: + "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", } if backend_name not in backend_map: @@ -165,9 +173,11 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", tensor_parallel_size: int = 1, max_model_len: int = 1024, dtype: Union[ModelDType, torch.dtype] = "auto", + num_gpu_blocks: int = 1000, block_size: int = 16, max_num_seqs: int = 256, max_num_batched_tokens: int = 8192, + enable_chunked_prefill: bool = True, add_mock_model_methods: bool = True) -> VllmConfig: """Create a VllmConfig for testing with reasonable defaults.""" @@ -187,7 +197,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", ) # Set cache blocks for testing # (these may be set during initialization normally) - cache_config.num_gpu_blocks = 1000 + cache_config.num_gpu_blocks = num_gpu_blocks 
cache_config.num_cpu_blocks = 0 parallel_config = ParallelConfig( @@ -196,6 +206,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, ) device_config = DeviceConfig() diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py index 3ccefbd81cab5..c153e38fe3df3 100644 --- a/tests/v1/core/test_async_scheduler.py +++ b/tests/v1/core/test_async_scheduler.py @@ -7,6 +7,7 @@ import pytest from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import RequestStatus +from vllm.v1.utils import ConstantList from .utils import create_requests, create_scheduler @@ -21,7 +22,6 @@ def _make_model_runner_output( for i, req_id in enumerate(req_ids) }, sampled_token_ids=[[i] for i in range(len(req_ids))], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -140,7 +140,8 @@ def test_prefix_caching_for_prefill_dedup(): requests = create_requests(num_requests=5, num_tokens=num_prompt_tokens, max_tokens=3, - same_prompt=True) + same_prompt=True, + block_size=BLOCK_SIZE) requests_copy = requests.copy() # Two requests with the same prompt. @@ -188,7 +189,8 @@ def test_prefix_caching_for_multi_turn(): block_size=BLOCK_SIZE) requests = create_requests(num_requests=5, num_tokens=num_prompt_tokens, - max_tokens=num_output_tokens) + max_tokens=num_output_tokens, + block_size=BLOCK_SIZE) for req in requests: scheduler.add_request(req) @@ -208,14 +210,19 @@ def test_prefix_caching_for_multi_turn(): # Create next-turn requests whose prompts are the full output of the # previous turn. 
- next_turn_requests = create_requests( - num_requests=5, - num_tokens=num_prompt_tokens + num_output_tokens, - max_tokens=num_output_tokens, - ) + next_turn_requests = create_requests(num_requests=5, + num_tokens=num_prompt_tokens + + num_output_tokens, + max_tokens=num_output_tokens, + block_size=BLOCK_SIZE) for i, req in enumerate(next_turn_requests): req.prompt_token_ids = (requests[i].prompt_token_ids + list(requests[i].output_token_ids)) + req._all_token_ids = req.prompt_token_ids.copy() + req.all_token_ids = ConstantList(req._all_token_ids) + req.block_hashes = [] + req.block_hashes = req.get_hash_new_full_blocks() + # Schedule the next-turn requests. for req in next_turn_requests: scheduler.add_request(req) diff --git a/tests/v1/core/test_encoder_cache_manager.py b/tests/v1/core/test_encoder_cache_manager.py new file mode 100644 index 0000000000000..ae5b751f45a4b --- /dev/null +++ b/tests/v1/core/test_encoder_cache_manager.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.v1.core.encoder_cache_manager import EncoderCacheManager + + +# ------------------ Mock Classes ------------------ # +class MockRequest: + + def __init__(self, request_id, mm_hashes, token_counts): + self.request_id = request_id + self.mm_hashes = mm_hashes + self._token_counts = token_counts + + def get_num_encoder_tokens(self, input_id: int) -> int: + return self._token_counts[input_id] + + +# ------------------ Unit Tests ------------------ # +def test_basic_allocate_and_reuse(): + cache = EncoderCacheManager(cache_size=10) + req = MockRequest("r1", ["imgA"], [4]) + + assert not cache.check_and_update_cache(req, 0) + assert cache.can_allocate(req, 0, int(1e9), 0) + + cache.allocate(req, 0) + + assert cache.check_and_update_cache(req, 0) + assert "r1" in cache.cached["imgA"] + assert cache.num_free_slots == 6 + + # Free twice to bring refcount to 0. 
+ cache.free_encoder_input(req, 0) + cache.free_encoder_input(req, 0) + + assert not cache.cached["imgA"] + assert "imgA" in cache.freeable + assert cache.num_freeable_slots == 10 + assert cache.num_free_slots == 6 + + +def test_freeing_decreases_refcount_and_moves_to_freeable(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("req2", ["img3"], [5]) + + assert manager.can_allocate(req, 0, int(1e9), 0) + manager.allocate(req, 0) + + assert len(manager.cached["img3"]) == 1 + + manager.free_encoder_input(req, 0) + + assert not manager.cached["img3"] + assert "img3" in manager.freeable + assert manager.num_freeable_slots == 10 + + +def test_free_request_frees_all_inputs(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("req3", ["a", "b"], [2, 3]) + + assert manager.can_allocate(req, 0, int(1e9), 0) + manager.allocate(req, 0) + + assert manager.can_allocate(req, 1, int(1e9), 0) + manager.allocate(req, 1) + + assert len(manager.cached["a"]) == 1 + assert len(manager.cached["b"]) == 1 + + manager.free(req) + + assert not manager.cached["a"] + assert not manager.cached["b"] + assert "a" in manager.freeable + assert "b" in manager.freeable + assert manager.num_freeable_slots == 10 + + +def test_eviction_when_cache_is_full(): + manager = EncoderCacheManager(cache_size=10) + + req1 = MockRequest("req1", ["x"], [6]) + req2 = MockRequest("req2", ["y"], [5]) + + assert manager.can_allocate(req1, 0, int(1e9), 0) + manager.allocate(req1, 0) + manager.free_encoder_input(req1, 0) + + assert manager.can_allocate(req2, 0, int(1e9), 0) + manager.allocate(req2, 0) + + # 'x' should have been evicted. 
+ assert "x" not in manager.cached + assert "x" in manager.get_freed_mm_hashes() + + +def test_get_cached_input_ids(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("reqX", ["m", "n", "o"], [2, 4, 3]) + + assert manager.can_allocate(req, 0, int(1e9), 0) + manager.allocate(req, 0) + + assert manager.can_allocate(req, 2, int(1e9), 0) + manager.allocate(req, 2) + + cached_ids = manager.get_cached_input_ids(req) + assert cached_ids == {0, 2} + + +def test_has_cache_restores_from_freeable(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("reqY", ["imgZ"], [4]) + + assert manager.can_allocate(req, 0, int(1e9), 0) + manager.allocate(req, 0) + + manager.free_encoder_input(req, 0) + + # Should restore from freeable. + assert manager.check_and_update_cache(req, 0) + assert len(manager.cached["imgZ"]) == 1 + assert "imgZ" not in manager.freeable + assert manager.num_freeable_slots == 6 + + +def test_get_freed_mm_hashes_clears_freed_list(): + manager = EncoderCacheManager(cache_size=10) + req1 = MockRequest("reqA", ["a"], [5]) + req2 = MockRequest("reqB", ["b"], [6]) + + assert manager.can_allocate(req1, 0, int(1e9), 0) + manager.allocate(req1, 0) + manager.free_encoder_input(req1, 0) + + # Should trigger eviction of 'a'. 
+ assert manager.can_allocate(req2, 0, int(1e9), 0) + manager.allocate(req2, 0) + + freed = manager.get_freed_mm_hashes() + assert "a" in freed + assert manager.get_freed_mm_hashes() == [] + + +def test_schedule_request_multi_images_respect_space_limit(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("reqA", ["a", "b"], [5, 6]) + compute_budget = 100 + + num_tokens_to_schedule = 0 + assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule) + num_tokens_to_schedule += req.get_num_encoder_tokens(0) + compute_budget -= req.get_num_encoder_tokens(0) + + assert not manager.can_allocate(req, 1, compute_budget, + num_tokens_to_schedule) + + +def test_schedule_request_multi_images_respect_compute_limit(): + manager = EncoderCacheManager(cache_size=100) + req = MockRequest("reqA", ["a", "b"], [5, 6]) + compute_budget = 10 + num_tokens_to_schedule = 0 + assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule) + num_tokens_to_schedule += req.get_num_encoder_tokens(0) + compute_budget -= req.get_num_encoder_tokens(0) + + assert not manager.can_allocate(req, 1, compute_budget, + num_tokens_to_schedule) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 182ea2b2345c4..47c74aff1e753 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -1,15 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib -from typing import Optional +from typing import Callable, Optional import pytest import torch from vllm.config import ModelConfig, SchedulerConfig, VllmConfig -from vllm.multimodal.inputs import (MultiModalBatchedField, - MultiModalFieldElem, MultiModalKwargsItem, - PlaceholderRange) +from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit from 
vllm.v1.core.kv_cache_manager import KVCacheManager @@ -19,7 +17,7 @@ from vllm.v1.core.kv_cache_utils import ( FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, estimate_max_model_len, generate_block_hash_extra_keys, get_kv_cache_config, get_max_concurrency_for_kv_cache_config, - hash_block_tokens, hash_request_tokens, init_none_hash, + get_request_block_hasher, hash_block_tokens, init_none_hash, is_kv_cache_type_uniform, unify_kv_cache_configs) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheTensor, @@ -33,6 +31,8 @@ from vllm.v1.request import Request def make_request( request_id: str, prompt_token_ids: list[int], + block_size: int = 3, + hash_fn: Callable = hash, mm_positions: Optional[list[PlaceholderRange]] = None, mm_hashes: Optional[list[str]] = None, cache_salt: Optional[str] = None, @@ -40,27 +40,20 @@ def make_request( if mm_positions is None: mm_kwargs = None else: - mm_elem = MultiModalFieldElem( - modality="dummy_m", - key="dummy_k", - data=None, - field=MultiModalBatchedField(), - ) - mm_item = MultiModalKwargsItem.from_elems([mm_elem]) + mm_item = MultiModalKwargsItem.dummy("dummy_m") mm_kwargs = [mm_item] * len(mm_positions) - return Request( - request_id=request_id, - prompt_token_ids=prompt_token_ids, - multi_modal_kwargs=mm_kwargs, - multi_modal_hashes=mm_hashes, - multi_modal_placeholders=mm_positions, - sampling_params=SamplingParams(max_tokens=17), - pooling_params=None, - eos_token_id=100, - lora_request=None, - cache_salt=cache_salt, - ) + return Request(request_id=request_id, + prompt_token_ids=prompt_token_ids, + multi_modal_kwargs=mm_kwargs, + multi_modal_hashes=mm_hashes, + multi_modal_placeholders=mm_positions, + sampling_params=SamplingParams(max_tokens=17), + pooling_params=None, + eos_token_id=100, + lora_request=None, + cache_salt=cache_salt, + block_hasher=get_request_block_hasher(block_size, hash_fn)) def new_kv_cache_spec(block_size=16, @@ -428,12 +421,14 @@ def 
test_hash_block_tokens(hash_fn): @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) -def test_hash_request_tokens(hash_fn): +def test_request_block_hasher(hash_fn): import vllm.v1.core.kv_cache_utils init_none_hash(hash_fn) request = make_request( request_id="0", prompt_token_ids=[_ for _ in range(6)], + block_size=3, + hash_fn=hash_fn, mm_positions=[ PlaceholderRange(offset=0, length=3), PlaceholderRange(offset=3, length=3), @@ -441,9 +436,7 @@ def test_hash_request_tokens(hash_fn): mm_hashes=["hash1", "hash2"], ) - block_size = 3 - block_hashes = hash_request_tokens(hash_fn, block_size, request) - + block_hashes = request.block_hashes assert len(block_hashes) == 2 assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash) assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash) @@ -464,6 +457,8 @@ def test_hash_tokens_different_mm_input(hash_fn): request1 = make_request( request_id="0", prompt_token_ids=[_ for _ in range(6)], + block_size=3, + hash_fn=hash_fn, mm_positions=[ PlaceholderRange(offset=0, length=3), PlaceholderRange(offset=3, length=3), @@ -479,9 +474,8 @@ def test_hash_tokens_different_mm_input(hash_fn): ], mm_hashes=["hash3", "hash2"], ) - block_size = 3 - block_hashes1 = hash_request_tokens(hash_fn, block_size, request1) - block_hashes2 = hash_request_tokens(hash_fn, block_size, request2) + block_hashes1 = request1.block_hashes + block_hashes2 = request2.block_hashes assert block_hashes1[0] != block_hashes2[0] assert block_hashes1[1] != block_hashes2[1] @@ -493,12 +487,13 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn): request = make_request( request_id="0", prompt_token_ids=[_ for _ in range(6)], + block_size=3, + hash_fn=hash_fn, mm_positions=None, mm_hashes=None, ) - block_size = 3 - block_hashes = hash_request_tokens(hash_fn, block_size, request) + block_hashes = request.block_hashes assert len(block_hashes) == 2 assert block_hashes[0].token_ids == (0, 1, 2) @@ -858,6 +853,7 @@ def 
test_allocate_with_lookahead(): request = make_request( request_id="0", prompt_token_ids=[], + block_size=block_size, mm_positions=None, mm_hashes=None, ) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 87acdef220133..89824768ed909 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -3,22 +3,21 @@ """Compare the with and without prefix caching.""" import copy -from typing import Optional +from typing import Callable, Optional import pytest import torch from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved -from vllm.multimodal.inputs import (MultiModalBatchedField, - MultiModalFieldElem, MultiModalKwargsItem, - PlaceholderRange) +from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.utils import sha256, sha256_cbor_64bit from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - KVCacheBlock, hash_block_tokens, - init_none_hash) + KVCacheBlock, + get_request_block_hasher, + hash_block_tokens, init_none_hash) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, SlidingWindowSpec) @@ -26,6 +25,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, def make_request( request_id: str, prompt_token_ids: list[int], + block_size: int, + hash_fn: Callable, mm_positions: Optional[list[PlaceholderRange]] = None, mm_hashes: Optional[list[str]] = None, prompt_logprobs: Optional[int] = None, @@ -34,28 +35,21 @@ def make_request( if mm_positions is None: mm_kwargs = None else: - mm_elem = MultiModalFieldElem( - modality="dummy_m", - key="dummy_k", - data=None, - field=MultiModalBatchedField(), - ) - mm_item = MultiModalKwargsItem.from_elems([mm_elem]) + mm_item = MultiModalKwargsItem.dummy("dummy_m") mm_kwargs 
= [mm_item] * len(mm_positions) - return Request( - request_id=request_id, - prompt_token_ids=prompt_token_ids, - multi_modal_kwargs=mm_kwargs, - multi_modal_hashes=mm_hashes, - multi_modal_placeholders=mm_positions, - sampling_params=SamplingParams(max_tokens=17, - prompt_logprobs=prompt_logprobs), - pooling_params=None, - eos_token_id=100, - lora_request=None, - cache_salt=cache_salt, - ) + return Request(request_id=request_id, + prompt_token_ids=prompt_token_ids, + multi_modal_kwargs=mm_kwargs, + multi_modal_hashes=mm_hashes, + multi_modal_placeholders=mm_positions, + sampling_params=SamplingParams( + max_tokens=17, prompt_logprobs=prompt_logprobs), + pooling_params=None, + eos_token_id=100, + lora_request=None, + cache_salt=cache_salt, + block_hasher=get_request_block_hasher(block_size, hash_fn)) def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: @@ -105,11 +99,11 @@ def make_kv_cache_config_hybrid_model(block_size: int, @pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"]) def test_prefill(hash_algo): + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, - caching_hash_algo=hash_algo, ) # choose the hash function according to the parameter @@ -123,9 +117,9 @@ def test_prefill(hash_algo): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids) + req0 = make_request("0", all_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id]) == 3 + assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, @@ -152,9 +146,10 @@ def test_prefill(hash_algo): # Cache hit in the common prefix when the original block is still in use. 
# Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, + hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -187,9 +182,10 @@ def test_prefill(hash_algo): # Cache hit in the common prefix when the original block is already free. # Incomplete 1 block (6 tokens) unique_token_ids = [3] * 6 - req2 = make_request("2", common_token_ids + unique_token_ids) + req2 = make_request("2", common_token_ids + unique_token_ids, block_size, + hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(manager.req_to_block_hashes[req2.request_id]) == 3 + assert len(req2.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -208,7 +204,7 @@ def test_prefill(hash_algo): manager.free(req2) # Cache miss and eviction. 
- req3 = make_request("3", [99] * (16 * 10)) + req3 = make_request("3", [99] * (16 * 10), block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -242,9 +238,9 @@ def test_prefill_hybrid_model(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids) + req0 = make_request("0", all_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id]) == 3 + assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, @@ -274,9 +270,10 @@ def test_prefill_hybrid_model(): # Cache hit in the common prefix # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, + hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6, 7], [0, 10, 11]) assert num_computed_tokens == 3 * 16 @@ -290,7 +287,7 @@ def test_prefill_hybrid_model(): if block != manager.block_pool.null_block: assert block.ref_cnt == 2 - block_hashes = manager.req_to_block_hashes[req1.request_id] + block_hashes = req1.block_hashes manager.free(req0) manager.free(req1) @@ -300,12 +297,13 @@ def test_prefill_hybrid_model(): def test_partial_request_hit(request_id: str, hash_to_evict: list[BlockHashWithGroupId], expect_hit_length: int): - req = make_request(request_id, common_token_ids + unique_token_ids) + req = make_request(request_id, common_token_ids + unique_token_ids, + block_size, hash) for 
hash_with_group_id in hash_to_evict: manager.block_pool.cached_block_hash_to_block.pop( hash_with_group_id) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert len(manager.req_to_block_hashes[req.request_id]) == 3 + assert len(req.block_hashes) == 3 assert num_computed_tokens == expect_hit_length * block_size for block_per_group in computed_blocks.blocks: assert len(block_per_group) == num_computed_tokens // block_size @@ -364,8 +362,9 @@ def test_prefill_plp(): 2. Schedule non-plp request and validate blocks 3. Schedule plp request; no hit should occur; validate blocks ''' + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) @@ -380,9 +379,13 @@ def test_prefill_plp(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids, prompt_logprobs=5) + req0 = make_request("0", + all_token_ids, + block_size, + hash_fn, + prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id]) == 0 + assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, @@ -411,9 +414,10 @@ def test_prefill_plp(): # Cache hit in the common prefix when the original block is still in use. 
# Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, + hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -447,9 +451,11 @@ def test_prefill_plp(): unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids, + block_size, + hash_fn, prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(manager.req_to_block_hashes[req2.request_id]) == 0 + assert len(req2.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 55, @@ -469,8 +475,9 @@ def test_prefill_plp(): def test_decode(): + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) @@ -481,7 +488,8 @@ def test_decode(): # Fully cache miss # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 - req0 = make_request("0", common_token_ids + unique_token_ids) + req0 = make_request("0", common_token_ids + unique_token_ids, block_size, + hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -518,14 +526,15 @@ def test_decode(): def test_evict(): + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) last_token_id = 5 * 16 + 7 - req0 = make_request("0", list(range(last_token_id))) + req0 = make_request("0", list(range(last_token_id)), block_size, hash) computed_blocks, 
num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -536,7 +545,8 @@ def test_evict(): # 3 blocks. req1 = make_request("1", list(range(last_token_id, - last_token_id + 3 * 16))) + last_token_id + 3 * 16)), block_size, + hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -558,7 +568,7 @@ def test_evict(): ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7] # Touch the first 2 blocks. - req2 = make_request("2", list(range(2 * 16 + 3))) + req2 = make_request("2", list(range(2 * 16 + 3)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert computed_blocks.get_block_ids() == ([1, 2], ) assert num_computed_tokens == 2 * 16 @@ -583,7 +593,7 @@ def test_hash_block_correct_reuse(): # Allocate 1 block and cache it. num_tokens = block_size * 1 - req = make_request("0", list(range(num_tokens))) + req = make_request("0", list(range(num_tokens)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -597,7 +607,7 @@ def test_hash_block_correct_reuse(): # Allocate a new block that's not full, make sure hash info on the # block is cleared. - req = make_request("1", list(range(num_tokens - 1))) + req = make_request("1", list(range(num_tokens - 1)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -624,7 +634,7 @@ def test_computed_blocks_not_evicted(): # Allocate a block and cache it. 
num_tokens = block_size * 1 - req0 = make_request("0", list(range(num_tokens))) + req0 = make_request("0", list(range(num_tokens)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -635,7 +645,8 @@ def test_computed_blocks_not_evicted(): assert blocks.blocks[0][0].block_id == 1 # Allocate another block. - req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) + req1 = make_request("1", list(range(num_tokens, num_tokens * 2)), + block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -651,7 +662,7 @@ def test_computed_blocks_not_evicted(): # Now if we have a cache hit on the first block, we should evict the second # cached block rather than the first one. - req2 = make_request("2", list(range(num_tokens * 2))) + req2 = make_request("2", list(range(num_tokens * 2)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 1 assert computed_blocks.blocks[0][0].block_id == 1 @@ -675,7 +686,8 @@ def test_basic_prefix_caching_disabled(): enable_caching=False, ) - req1 = make_request("1", list(range(10))) # 2 blocks and some more + req1 = make_request("1", list(range(10)), block_size, + hash) # 2 blocks and some more computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] @@ -689,7 +701,8 @@ def test_basic_prefix_caching_disabled(): manager.free(req1) # No caching. 
- req2 = make_request("2", list(range(16))) # shared prefix + req2 = make_request("2", list(range(16)), block_size, + hash) # shared prefix computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -699,7 +712,7 @@ def test_basic_prefix_caching_disabled(): assert len(blocks.blocks[0]) == 4 # New requests should not have any blocks. - req3 = make_request("3", list(range(4))) + req3 = make_request("3", list(range(4)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -727,20 +740,17 @@ def test_cache_blocks(hash_fn): # Block 1: [4, 5, 6, 7] # Block 2: [8, 9, 10, 11] # Block 3: [12, 13] - req = make_request("0", list(range(14))) + req = make_request("0", list(range(14)), block_size, hash_fn) # Test that blocks are cached correctly for 2 full blocks from the start. blocks = [KVCacheBlock(block_id=i) for i in range(2)] - block_hashes: list[BlockHash] = [] block_pool.cache_full_blocks( request=req, blocks=blocks, - block_hashes=block_hashes, num_cached_blocks=0, num_full_blocks=2, block_size=block_size, - hash_fn=hash_fn, kv_cache_group_id=0, ) @@ -752,11 +762,9 @@ def test_cache_blocks(hash_fn): block_pool.cache_full_blocks( request=req, blocks=blocks, - block_hashes=block_hashes, num_cached_blocks=2, num_full_blocks=3, block_size=block_size, - hash_fn=hash_fn, kv_cache_group_id=0, ) assert len(block_pool.cached_block_hash_to_block) == 3 @@ -775,23 +783,20 @@ def test_cache_blocks_multi_group(): # Block 1/5: [4, 5, 6, 7] # Block 2/6: [8, 9, 10, 11] # Block 3/7: [12, 13] - req = make_request("0", list(range(14))) + req = make_request("0", list(range(14)), block_size, hash) # Cache the blocks for group 0. 
blocks = [KVCacheBlock(block_id=i) for i in range(2)] - block_hashes: list[BlockHash] = [] block_pool.cache_full_blocks( request=req, blocks=blocks, - block_hashes=block_hashes, num_cached_blocks=0, num_full_blocks=2, block_size=block_size, - hash_fn=hash, kv_cache_group_id=0, ) assert len(block_pool.cached_block_hash_to_block) == 2 - assert len(block_hashes) == 2 + assert len(req.block_hashes) == 3 assert all([block.block_hash is not None for block in blocks]) # Cache the blocks for group 1. @@ -799,38 +804,36 @@ def test_cache_blocks_multi_group(): block_pool.cache_full_blocks( request=req, blocks=blocks, - block_hashes=block_hashes, num_cached_blocks=0, num_full_blocks=3, block_size=block_size, - hash_fn=hash, kv_cache_group_id=1, ) assert len(block_pool.cached_block_hash_to_block) == 5 - assert len(block_hashes) == 3 + assert len(req.block_hashes) == 3 assert all([block.block_hash is not None for block in blocks]) # Block hash 0: hit for group 0 and 1 # Block hash 1: hit for group 0 and 1 # Block hash 2: hit for group 1 - assert block_pool.get_cached_block(block_hashes[0], + assert block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[0]) is not None - assert block_pool.get_cached_block(block_hashes[1], + assert block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[0]) is not None - assert block_pool.get_cached_block(block_hashes[2], + assert block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[0]) is None - assert block_pool.get_cached_block(block_hashes[0], + assert block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(block_hashes[1], + assert block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(block_hashes[2], + assert block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(block_hashes[0], + assert 
block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[0, 1]) is not None - assert block_pool.get_cached_block(block_hashes[1], + assert block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[0, 1]) is not None - assert block_pool.get_cached_block(block_hashes[2], + assert block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[0, 1]) is None @@ -838,8 +841,9 @@ def test_mm_prefix_caching(): """ This tests that the multi-modal prefix caching is correct. """ + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) @@ -865,6 +869,8 @@ def test_mm_prefix_caching(): mm_hashes = common_mm_hashes + ["ccc"] req0 = make_request("0", all_token_ids, + block_size, + hash, mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) @@ -872,7 +878,7 @@ def test_mm_prefix_caching(): # Completed block should have hashes with extra keys. assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req0.request_id] + block_hashes = req0.block_hashes assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("aaa", ) assert block_hashes[1].extra_keys == ("aaa", "bbb") @@ -905,6 +911,8 @@ def test_mm_prefix_caching(): mm_hashes = common_mm_hashes + ["ccc"] req1 = make_request("1", all_token_ids, + block_size, + hash, mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) @@ -927,13 +935,13 @@ def test_cache_key_salting(): # 3 complete blocks and an incomplete block with 11 tokens. 
common_token_ids = [i for i in range(3) for _ in range(block_size)] token_ids = common_token_ids + [3] * 11 - req0 = make_request("0", token_ids, cache_salt="salt1") + req0 = make_request("0", token_ids, block_size, hash, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes with extra keys. assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req0.request_id] + block_hashes = req0.block_hashes assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("salt1", ) assert block_hashes[1].extra_keys is None @@ -959,7 +967,7 @@ def test_cache_key_salting(): # Test cache hit with a new request that has the same salt. token_ids = common_token_ids + [4] * 11 - req1 = make_request("1", token_ids, cache_salt="salt1") + req1 = make_request("1", token_ids, block_size, hash, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) # Should match only a prefix of 3 blocks. assert len(computed_blocks.blocks[0]) == 3 @@ -967,11 +975,11 @@ def test_cache_key_salting(): # Test cache miss with same content but different salt. token_ids = common_token_ids + [4] * 11 - req2 = make_request("2", token_ids, cache_salt="salt2") + req2 = make_request("2", token_ids, block_size, hash, cache_salt="salt2") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 0 assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req2.request_id] + block_hashes = req2.block_hashes assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("salt2", ) @@ -992,7 +1000,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # Complete 3 blocks (48 tokens) # | Common-0 | Common-1 | Common-2 | ... 
| common_token_ids = [i for i in range(3) for _ in range(16)] - req0 = make_request("0", common_token_ids) + req0 = make_request("0", common_token_ids, block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1003,7 +1011,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): req0.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | - req1 = make_request("1", common_token_ids * 2) + req1 = make_request("1", common_token_ids * 2, block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert computed_blocks.blocks[0] == block_part0 assert num_computed_tokens == 3 * 16 @@ -1020,19 +1028,19 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| Req2-0 | Req2-1 | ... | - req2 = make_request("2", [7] * block_size * 2) + req2 = make_request("2", [7] * block_size * 2, block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots(req2, block_size * 2, - len(computed_blocks.blocks[0]) * 16, + len(computed_blocks.blocks[0]) * block_size, computed_blocks) # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # but it cannot be allocated due to insufficient free blocks (2). # In this case, the ref_cnt of the computed blocks should not be changed. 
assert manager.block_pool.free_block_queue.num_free_blocks == 5 - req3 = make_request("3", common_token_ids * 3) + req3 = make_request("3", common_token_ids * 3, block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert computed_blocks.blocks[0] == block_part1 assert num_computed_tokens == 6 * 16 @@ -1047,8 +1055,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): def test_reset_prefix_cache(): + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) @@ -1056,15 +1065,15 @@ def test_reset_prefix_cache(): full_block_token_ids = [i for i in range(3) for _ in range(16)] unique_token_ids = [3] * 7 all_token_ids = full_block_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids) + req0 = make_request("0", all_token_ids, block_size, hash) blocks = manager.allocate_slots(req0, 55) assert blocks.get_block_ids() == ([1, 2, 3, 4], ) unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids - req1 = make_request("1", all_token_ids) + req1 = make_request("1", all_token_ids, block_size, hash) computed_blocks, _ = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(req1.block_hashes) == 3 assert len(computed_blocks.blocks[0]) == 3 blocks = manager.allocate_slots(req1, 7, len(computed_blocks.blocks[0]) * 16, @@ -1086,8 +1095,9 @@ def test_reset_prefix_cache(): def test_prefix_cache_stats_disabled(): """Test that prefix_cache_stats is None when log_stats is False.""" + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, log_stats=False, # Disable logging stats @@ -1095,7 +1105,7 @@ def test_prefix_cache_stats_disabled(): assert manager.prefix_cache_stats is None # Call all functions that check whether log_stats is disabled. 
- req = make_request("0", list(range(16))) + req = make_request("0", list(range(16)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1192,7 +1202,7 @@ def test_kv_cache_events(blocks_to_cache: int): ) num_tokens = block_size * blocks_to_cache - req0 = make_request("0", list(range(num_tokens))) + req0 = make_request("0", list(range(num_tokens)), block_size, hash) _ = manager.allocate_slots(req0, num_tokens) events = manager.take_events() @@ -1208,7 +1218,7 @@ def test_kv_cache_events(blocks_to_cache: int): # Should see block_to_cache number of removed block events and a new block # stored event manager.free(req0) - req1 = make_request("1", list(range(num_tokens))) + req1 = make_request("1", list(range(num_tokens)), block_size, hash) _ = manager.allocate_slots(req1, num_tokens) events = manager.take_events() @@ -1242,7 +1252,7 @@ def test_eagle_enabled_removes_last_block(): # Request with 3 full blocks (48 tokens) token_ids = [0] * (3 * block_size) - req = make_request("divisible_request", token_ids) + req = make_request("divisible_request", token_ids, block_size, hash) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1252,7 +1262,7 @@ def test_eagle_enabled_removes_last_block(): manager.free(req) # New request with same tokens + Eagle enabled - req_eagle = make_request("eagle_divisible", token_ids) + req_eagle = make_request("eagle_divisible", token_ids, block_size, hash) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Should retain 1 block: @@ -1273,7 +1283,7 @@ def test_eagle_with_partial_blocks(): ) # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) - req = make_request("partial_block_test", token_ids) + req = make_request("partial_block_test", token_ids, block_size, hash) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1283,7 +1293,7 
@@ def test_eagle_with_partial_blocks(): manager.free(req) # New request with Eagle enabled - req_eagle = make_request("partial_eagle", token_ids) + req_eagle = make_request("partial_eagle", token_ids, block_size, hash) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 @@ -1314,7 +1324,7 @@ def test_eagle_with_sliding_window(): # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) - req = make_request("partial_block_test", token_ids) + req = make_request("partial_block_test", token_ids, block_size, hash) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1322,12 +1332,12 @@ def test_eagle_with_sliding_window(): len(computed_blocks.blocks[0]) * 16, computed_blocks) # record the block hash of the first block in the request for later use - block_hash_first_block = manager.req_to_block_hashes[req.request_id][0] + block_hash_first_block = req.block_hashes[0] assert block_hash_first_block is not None manager.free(req) # New request with Eagle enabled - req_eagle = make_request("partial_eagle", token_ids) + req_eagle = make_request("partial_eagle", token_ids, block_size, hash) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 @@ -1340,7 +1350,8 @@ def test_eagle_with_sliding_window(): BlockHashWithGroupId(block_hash_first_block, 0)) # New request - req_after_evict = make_request("partial_eagle_after_evict", token_ids) + req_after_evict = make_request("partial_eagle_after_evict", token_ids, + block_size, hash) computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict) # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # not considered. 
But after dropping the last matched block due to eagle, diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 1c7dd0ca90b7e..07d7c12a4f5ef 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -8,15 +8,13 @@ import torch from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) -from vllm.multimodal.inputs import (MultiModalBatchedField, - MultiModalFieldElem, MultiModalKwargsItem, - PlaceholderRange) +from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec) -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output.request import StructuredOutputRequest @@ -160,7 +158,6 @@ def test_schedule_partial_requests(): # Only the first request has a sampled token id because # the rest requests are still being prefilled. 
sampled_token_ids=[[0], [], []], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -211,7 +208,6 @@ def test_no_mm_input_chunking(): req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[[] for _ in range(len(requests))], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -275,7 +271,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[[] for _ in range(len(requests))], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -300,7 +295,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -344,7 +338,7 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None) @@ -357,7 +351,6 @@ def test_stop_via_update_from_output(): sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]], # First request hits EOS, second continues - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[]) @@ -398,7 +391,7 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -411,7 +404,6 @@ def test_stop_via_update_from_output(): }, sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[]) @@ -451,7 +443,7 @@ def 
test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -464,7 +456,6 @@ def test_stop_via_update_from_output(): }, sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[]) @@ -499,7 +490,7 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None) @@ -507,7 +498,6 @@ def test_stop_via_update_from_output(): req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[]) @@ -556,7 +546,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[0]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -574,7 +563,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, sampled_token_ids=[[0]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -589,7 +577,7 @@ def test_preempt_during_execution(): block_size=16, num_blocks=11, enable_prefix_caching=False) - requests = create_requests(num_requests=2, num_tokens=80) + requests = create_requests(num_requests=2, num_tokens=80, block_size=16) # Schedule the first request. 
scheduler.add_request(requests[0]) @@ -610,7 +598,6 @@ def test_preempt_during_execution(): req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[0]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -628,7 +615,6 @@ def test_preempt_during_execution(): req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, sampled_token_ids=[[42]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -684,13 +670,14 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[0] for _ in range(len(requests))], - spec_token_ids=spec_tokens, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) engine_core_outputs = scheduler.update_from_output(output, model_runner_output) + draft_token_ids = DraftTokenIds(req_ids, spec_tokens) + scheduler.update_draft_token_ids(draft_token_ids) for i in range(len(requests)): running_req = scheduler.running[i] @@ -724,7 +711,6 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=output_tokens, - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -762,7 +748,7 @@ def _assert_right_scheduler_output( def _assert_right_kv_cache_manager( scheduler: Scheduler, - req_ids: list[str], + requests: list[Request], num_tokens: int, block_size: int, num_requests: int, @@ -772,12 +758,12 @@ def _assert_right_kv_cache_manager( # Make sure the request stats are right. EXPECTED_TOTAL_BLOCKS = num_tokens // block_size - for req_id in req_ids: + for req in requests: blocks = (scheduler.kv_cache_manager.coordinator. 
- single_type_managers[0].req_to_blocks[req_id]) - hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id] + single_type_managers[0].req_to_blocks[req.request_id]) + hashes = req.block_hashes assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS) + num_cached_block[req.request_id] == EXPECTED_TOTAL_BLOCKS) assert len(blocks) == EXPECTED_TOTAL_BLOCKS assert len(hashes) == EXPECTED_TOTAL_BLOCKS @@ -840,7 +826,8 @@ def test_kv_connector_basic(): MAX_TOKENS = 3 requests = create_requests(num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -852,7 +839,6 @@ def test_kv_connector_basic(): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -868,7 +854,7 @@ def test_kv_connector_basic(): ) # Ensure KVCacheManager is correct. - _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE, + _assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS) # Continue Generation until done. @@ -886,7 +872,8 @@ def test_kv_connector_basic(): NUM_TOKENS = NUM_TOKENS_PREFIX * 2 requests = create_requests(num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -898,7 +885,6 @@ def test_kv_connector_basic(): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -915,7 +901,7 @@ def test_kv_connector_basic(): NUM_MATCHED_NEW_TOKENS)) # Ensure KVCacheManager is correct. 
- _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE, + _assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS) # Continue Generation until done. @@ -953,7 +939,8 @@ def test_kv_connector_unable_to_allocate(): MAX_TOKENS = 2 requests = create_requests(num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -965,7 +952,6 @@ def test_kv_connector_unable_to_allocate(): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1034,7 +1020,8 @@ def test_kv_connector_handles_preemption(): MAX_TOKENS = BLOCK_SIZE * 2 requests = create_requests(num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -1046,7 +1033,6 @@ def test_kv_connector_handles_preemption(): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1140,7 +1126,6 @@ def make_output(scheduler: Scheduler): for i, req in enumerate(scheduler.running) }, sampled_token_ids=[[1000]] * len(scheduler.running), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1162,7 +1147,6 @@ def assert_scheduler_empty(scheduler: Scheduler): # KVCache Manager. assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. req_to_blocks) == 0 - assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. 
num_cached_block) == 0 num_free_blocks = ( @@ -1325,13 +1309,7 @@ def create_requests_with_priority( for i in range(num_requests): if mm_positions is not None: mm_position = mm_positions[i] - mm_elem = MultiModalFieldElem( - modality="dummy_m", - key="dummy_k", - data=None, - field=MultiModalBatchedField(), - ) - mm_item = MultiModalKwargsItem.from_elems([mm_elem]) + mm_item = MultiModalKwargsItem.dummy("dummy_m") mm_kwargs = [mm_item] * len(mm_position) else: mm_position = None @@ -1473,7 +1451,6 @@ def test_priority_scheduling_preemption(): for i, req in enumerate(low_priority_requests) }, sampled_token_ids=[[100] for _ in low_priority_requests], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1546,7 +1523,6 @@ def test_priority_scheduling_no_preemption_when_space_available(): for i, req in enumerate(low_priority_requests) }, sampled_token_ids=[[100] for _ in low_priority_requests], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1788,7 +1764,6 @@ def test_priority_scheduling_heap_property(): req_ids=[req.req_id], req_id_to_index={req.req_id: 0}, sampled_token_ids=[[100]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index b67c05bd7ac10..7dcebba491fab 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -17,7 +17,6 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, def get_sliding_window_manager(sliding_window_spec, block_pool): return SlidingWindowManager(sliding_window_spec, block_pool, - caching_hash_fn=lambda x: x, kv_cache_group_id=0) @@ -25,7 +24,6 @@ def get_chunked_local_attention_manager(chunked_local_attention_spec, block_pool): return ChunkedLocalAttentionManager(chunked_local_attention_spec, block_pool, - caching_hash_fn=lambda x: x, 
kv_cache_group_id=0) diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 484afe61fc3fb..78a71f10a5940 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -6,10 +6,10 @@ import torch from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) -from vllm.multimodal.inputs import (MultiModalBatchedField, - MultiModalFieldElem, MultiModalKwargsItem, - PlaceholderRange) +from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.sampling_params import SamplingParams +from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, + init_none_hash) from vllm.v1.core.sched.async_scheduler import AsyncScheduler from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -114,6 +114,9 @@ def create_scheduler( ) +_none_hash_initialized = False + + def create_requests( num_requests: int, num_tokens: int = 10, @@ -122,7 +125,14 @@ def create_requests( stop_token_ids: Optional[list[int]] = None, prompt_logprobs: Optional[int] = None, same_prompt: bool = False, + block_size: int = 16, ) -> list[Request]: + global _none_hash_initialized + if not _none_hash_initialized: + init_none_hash(hash) + _none_hash_initialized = True + + block_hasher = get_request_block_hasher(block_size, hash) sampling_params = SamplingParams(ignore_eos=False, max_tokens=max_tokens, stop_token_ids=stop_token_ids, @@ -131,17 +141,17 @@ def create_requests( for i in range(num_requests): if mm_positions is not None: mm_position = mm_positions[i] - mm_elem = MultiModalFieldElem( - modality="dummy_m", - key="dummy_k", - data=None, - field=MultiModalBatchedField(), - ) - mm_item = MultiModalKwargsItem.from_elems([mm_elem]) + mm_item = MultiModalKwargsItem.dummy("dummy_m") mm_kwargs = [mm_item] * len(mm_position) + # Dummy hash for each mm item should be unique + # since encoder cache tracks entries by hash + mm_hashes = [ + "hash" 
+ str(i) + "_" + str(j) for j in range(len(mm_position)) + ] else: mm_position = None mm_kwargs = None + mm_hashes = None prompt_token_ids = ([0] * num_tokens if same_prompt else [i] * num_tokens) request = Request( @@ -151,8 +161,9 @@ def create_requests( pooling_params=None, multi_modal_kwargs=mm_kwargs, multi_modal_placeholders=mm_position, - multi_modal_hashes=None, + multi_modal_hashes=mm_hashes, eos_token_id=EOS_TOKEN_ID, + block_hasher=block_hasher, ) requests.append(request) return requests diff --git a/tests/v1/e2e/test_min_tokens.py b/tests/v1/e2e/test_min_tokens.py new file mode 100644 index 0000000000000..f013425cb59df --- /dev/null +++ b/tests/v1/e2e/test_min_tokens.py @@ -0,0 +1,479 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Comprehensive end-to-end tests for `min_tokens` in the V1 engine. + +Addresses #21950: verify and add CI coverage. + +Covers: +1) Basic functionality +2) Stop strings with `min_tokens` (bug #21987; fix in PR #22014) +3) EOS behavior with `min_tokens` (potential logits-processor bug) +4) Edge cases (min_tokens == max_tokens, min_tokens == 0) +5) Multiple stop conditions +""" + +import os +from typing import Optional, Union + +import pytest + +from vllm import LLM, SamplingParams +from vllm.outputs import RequestOutput + +# Test configuration +TEST_MODEL = "facebook/opt-125m" # Small model for fast CI execution +GREEDY = 0.0 # Deterministic generation for consistent testing + + +class MinTokensTestCase: + """Data class for min_tokens test scenarios""" + + def __init__( + self, + name: str, + min_tokens: int, + max_tokens: int, + stop: Optional[Union[str, list[str]]] = None, + expected_min_len: Optional[int] = None, + expected_exact_len: Optional[int] = None, + ): + self.name = name + self.min_tokens = min_tokens + self.max_tokens = max_tokens + self.stop = stop + self.expected_min_len = expected_min_len or min_tokens + self.expected_exact_len = 
expected_exact_len + + def __str__(self): + return (f"{self.name}: min={self.min_tokens}, " + f"max={self.max_tokens}, stop={self.stop}") + + +# Test scenarios covering all critical cases +MIN_TOKENS_TEST_CASES = [ + # === BASIC FUNCTIONALITY (should work) === + MinTokensTestCase(name="basic_min_tokens_no_stop", + min_tokens=8, + max_tokens=20, + stop=None, + expected_min_len=8), + MinTokensTestCase(name="min_tokens_zero", + min_tokens=0, + max_tokens=10, + stop=None, + expected_min_len=0), + MinTokensTestCase(name="min_equals_max_no_stop", + min_tokens=15, + max_tokens=15, + stop=None, + expected_exact_len=15), + + # === STOP STRINGS WITH MIN_TOKENS === + # These tests expose the detokenizer bug where stop strings + # bypass min_tokens + # Using mathematically guaranteed approach with wide stop nets + pytest.param( + MinTokensTestCase( + name="min_tokens_with_comprehensive_stops", + min_tokens=5, + max_tokens=20, + stop=[ + "a", + "e", + "i", + "o", + "u", + "t", + "n", + "s", + "r", + "l", + " ", + ], + expected_min_len=5, + ), + marks=pytest.mark.xfail( + reason=("Known bug #21987: stop strings bypass min_tokens " + "(fixed by PR #22014)"), + strict=False), + id="min_tokens_with_comprehensive_stops", + ), + pytest.param( + MinTokensTestCase( + name="min_tokens_with_simple_char_stop", + min_tokens=3, + max_tokens=15, + stop=["e", "a", " "], + expected_min_len=3, + ), + marks=pytest.mark.xfail( + reason=("Known bug #21987: stop strings bypass min_tokens " + "(fixed by PR #22014)"), + strict=False), + id="min_tokens_with_simple_char_stop", + ), + + # === EOS TOKEN WITH MIN_TOKENS (potential LogitsProcessor bug) === + # These test the MinTokensLogitsProcessor handling of EOS tokens + pytest.param( + MinTokensTestCase( + name="min_equals_max_eos_only", + min_tokens=20, + max_tokens=20, + stop=None, # Relies on default EOS token behavior + expected_exact_len=20, + ), + marks=pytest.mark.xfail( + reason= + ("Potential logits-processor bug: EOS tokens may bypass 
min_tokens" + ), + strict=False, + ), + id="min_equals_max_eos_only", + ), + + # === EDGE CASES === + MinTokensTestCase(name="large_min_tokens", + min_tokens=50, + max_tokens=60, + stop=None, + expected_min_len=50), + MinTokensTestCase( + name="min_tokens_with_empty_stop_list", + min_tokens=5, + max_tokens=15, + stop=[], # Empty stop list + expected_min_len=5), +] + + +@pytest.fixture(scope="module") +def llm_v1(): + """Create V1 LLM instance for testing""" + # Ensure V1 engine is used + os.environ["VLLM_USE_V1"] = "1" + + llm = LLM( + model=TEST_MODEL, + tensor_parallel_size=1, + max_model_len=1024, # Small context for fast testing + enforce_eager=True, # Avoid graph compilation overhead + ) + return llm + + +def get_token_count(output: RequestOutput) -> int: + """Extract token count from LLM output""" + if not output.outputs: + return 0 + return len(output.outputs[0].token_ids) + + +def assert_min_tokens_satisfied(output: RequestOutput, + test_case: MinTokensTestCase) -> None: + """Assert that min_tokens requirement is satisfied""" + token_count = get_token_count(output) + stop_reason = (output.outputs[0].stop_reason + if output.outputs else "no output") + + if test_case.expected_exact_len is not None: + # Exact length requirement + assert token_count == test_case.expected_exact_len, ( + f"Expected exactly {test_case.expected_exact_len} tokens, " + f"got {token_count} tokens. " + f"Stop reason: {stop_reason}") + else: + # Minimum length requirement + assert token_count >= (test_case.expected_min_len or 0), ( + f"Expected at least {test_case.expected_min_len} tokens, " + f"got {token_count} tokens. " + f"Stop reason: {stop_reason}") + + +@pytest.mark.parametrize( + "test_case", + MIN_TOKENS_TEST_CASES, + ids=lambda tc: tc.name, +) +def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase): + """ + Comprehensive test for min_tokens functionality in V1 engine. 
+ + This test covers all critical scenarios for min_tokens: + - Basic functionality (should work) + - Stop strings with min_tokens (known bug) + - EOS tokens with min_tokens (potential bug) + - Edge cases + + Args: + llm_v1: V1 LLM instance + test_case: Test scenario parameters + """ + # Known failing cases are handled via param-level xfail marks above. + + # Create sampling parameters + sampling_params = SamplingParams( + min_tokens=test_case.min_tokens, + max_tokens=test_case.max_tokens, + stop=test_case.stop, + temperature=GREEDY, + include_stop_str_in_output=True # Include stop strings for debugging + ) + + # Use simple prompt. Comprehensive stop lists should catch any generation + prompt = "Hello" + + # Generate output + outputs = llm_v1.generate([prompt], sampling_params) + + assert len(outputs) == 1, "Expected exactly one output" + output = outputs[0] + + # Debug information + token_count = get_token_count(output) + generated_text = output.outputs[0].text if output.outputs else "" + stop_reason = output.outputs[0].stop_reason if output.outputs else "unknown" + + print(f"\nTest: {test_case.name}") + print(f"Generated {token_count} tokens") + print(f"Stop reason: {stop_reason}") + print(f"Generated text: {repr(generated_text)}") + print(f"Expected min: {test_case.expected_min_len}") + if test_case.expected_exact_len: + print(f"Expected exact: {test_case.expected_exact_len}") + + # Validate min_tokens requirement + assert_min_tokens_satisfied(output, test_case) + + +def test_min_tokens_basic_functionality(llm_v1: LLM): + """ + Test basic min_tokens functionality without stop conditions. + + This is a baseline test that should always pass and validates + that min_tokens works correctly in the simple case. 
+ """ + sampling_params = SamplingParams(min_tokens=10, + max_tokens=20, + temperature=GREEDY) + + prompt = "Once upon a time" + outputs = llm_v1.generate([prompt], sampling_params) + + assert len(outputs) == 1 + token_count = get_token_count(outputs[0]) + + assert token_count >= 10, f"Expected at least 10 tokens, got {token_count}" + assert token_count <= 20, f"Expected at most 20 tokens, got {token_count}" + + +@pytest.mark.xfail( + reason=("Known bug #21987: stop strings bypass min_tokens " + "(fixed by PR #22014)"), + strict=False, +) +def test_min_tokens_stop_strings_bug(llm_v1: LLM): + """ + Test the specific bug where stop strings bypass min_tokens. + + This test specifically reproduces the bug Calvin is fixing in PR #22014. + It should fail until that fix is merged. + + Strategy: Use guaranteed stop characters that will appear + in any generated text. + """ + # If the bug is fixed upstream, this test will XPASS + + sampling_params = SamplingParams( + min_tokens=15, + max_tokens=50, + # Common letter; likely appears early + stop=["e"], + temperature=GREEDY, + include_stop_str_in_output=True) + + # Simple prompt that will generate text containing "e" + prompt = "The quick brown fox" + outputs = llm_v1.generate([prompt], sampling_params) + + assert len(outputs) == 1 + token_count = get_token_count(outputs[0]) + generated_text = outputs[0].outputs[0].text if outputs[0].outputs else "" + + # Debug info to understand what happened + print(f"Generated text: {repr(generated_text)}") + print(f"Token count: {token_count}") + print(f"Contains 'e': {'e' in generated_text}") + + # This assertion should fail due to the bug - if stop string is found early, + # the model should still continue generating until min_tokens is reached + stop_reason = (outputs[0].outputs[0].stop_reason + if outputs[0].outputs else "no output") + assert token_count >= 15, ("Bug confirmed: " + f"{token_count} tokens < min_tokens=15. " + f"Reason: {stop_reason}. 
" + f"Text: {repr(generated_text)}") + + +@pytest.mark.xfail( + reason=("Known bug #21987: stop strings bypass min_tokens " + "(fixed by PR #22014)"), + strict=False, +) +def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM): + """ + Guaranteed test for stop strings bypassing min_tokens bug. + + Strategy: Use very low temperature and multiple common stop strings + to virtually guarantee early detection, combined with long min_tokens + to ensure the bug is exposed regardless of model behavior. + """ + # If the bug is fixed upstream, this test will XPASS + + sampling_params = SamplingParams( + min_tokens=50, # Set high min_tokens to ensure bug detection + max_tokens=200, + # Use multiple very common patterns - at least one will appear + stop=["e", "a", "i", "o", "u", " ", "t", "n", "s", "r"], + temperature=GREEDY, + include_stop_str_in_output=True) + + # Simple prompt that will generate some text + prompt = "The cat" + outputs = llm_v1.generate([prompt], sampling_params) + + assert len(outputs) == 1 + token_count = get_token_count(outputs[0]) + generated_text = outputs[0].outputs[0].text if outputs[0].outputs else "" + stop_reason = (outputs[0].outputs[0].stop_reason + if outputs[0].outputs else "unknown") + + print(f"Generated text: {repr(generated_text)}") + print(f"Token count: {token_count}") + print(f"Stop reason: {stop_reason}") + + # With the bug, this will fail because ANY of the common characters + # will trigger early termination before min_tokens=50 is reached + # It's virtually impossible to generate 50 tokens without hitting + # at least one of: e, a, i, o, u, space, t, n, s, r + finish_reason = (outputs[0].outputs[0].finish_reason + if outputs[0].outputs else "unknown") + + print(f"Finish reason: {finish_reason}") + + if finish_reason == "stop": + assert token_count >= 50, ("Bug confirmed: " + f"{token_count} tokens < min_tokens=50. " + f"Reason: {finish_reason}. 
" + f"Text: {repr(generated_text)}") + + +@pytest.mark.xfail( + reason=( + "Potential logits-processor bug: EOS tokens may bypass min_tokens"), + strict=False, +) +def test_min_tokens_eos_behavior(llm_v1: LLM): + """ + Verify EOS handling with and without min_tokens. + + - Without min_tokens: expect early EOS -> finish_reason == "stop", + stop_reason is None, and generated tokens < max_tokens (25). + - With min_tokens: EOS should be blocked until min_tokens is reached + (finish_reason == "length"); verify that eos_token_id does not appear + in generated token_ids. + """ + # tokenizer + eos id + tokenizer = llm_v1.get_tokenizer() + eos_token_id = tokenizer.eos_token_id + + prompt = "Give a file extension." + max_toks = 32 + + # Case 1: WITHOUT min_tokens + sp_no_min = SamplingParams( + max_tokens=max_toks, + temperature=GREEDY, + ) + out_no_min = llm_v1.generate([prompt], sp_no_min) + assert len(out_no_min) == 1 + choice_no_min = out_no_min[0].outputs[0] + + ids_no_min = choice_no_min.token_ids or [] + finish_no_min = choice_no_min.finish_reason + stop_no_min = choice_no_min.stop_reason + + print("[no-min] tokens=", len(ids_no_min), " finish=", finish_no_min, + " stop_reason=", stop_no_min) + + assert finish_no_min == "stop", ( + f"Expected finish_reason 'stop' without min_tokens, got {finish_no_min}" + ) + assert stop_no_min is None, ( + "For EOS-based stop (no user stop strings), stop_reason should be None." 
+ ) + assert len(ids_no_min) < max_toks, ( + f"Expected early EOS with < {max_toks} tokens, got {len(ids_no_min)}") + + # Case 2: WITH min_tokens + sp_with_min = SamplingParams( + min_tokens=max_toks, + max_tokens=max_toks, + temperature=GREEDY, + ) + out_with_min = llm_v1.generate([prompt], sp_with_min) + assert len(out_with_min) == 1 + choice_with_min = out_with_min[0].outputs[0] + + ids_with_min = choice_with_min.token_ids or [] + finish_with_min = choice_with_min.finish_reason + stop_with_min = choice_with_min.stop_reason + + print("[with-min] tokens=", len(ids_with_min), " finish=", finish_with_min, + " stop_reason=", stop_with_min) + + # Exact length reached; EOS should have been blocked + assert len(ids_with_min) == max_toks, ( + f"Expected exactly {max_toks} tokens with min_tokens; " + f"got {len(ids_with_min)}") + assert finish_with_min == "length", ( + f"Expected finish_reason 'length'; got {finish_with_min}") + assert eos_token_id not in ids_with_min, ( + "EOS token id should not appear when min_tokens prevents early EOS.") + + +def test_min_tokens_validation(): + """ + Test that SamplingParams correctly validates min_tokens parameters. + + This tests the parameter validation logic in SamplingParams. + """ + # Valid cases + SamplingParams(min_tokens=0, max_tokens=10) + SamplingParams(min_tokens=5, max_tokens=10) + SamplingParams(min_tokens=10, max_tokens=10) + + # Invalid cases + with pytest.raises( + ValueError, + match="min_tokens must be greater than or equal to 0", + ): + SamplingParams(min_tokens=-1, max_tokens=10) + + with pytest.raises( + ValueError, + match="min_tokens must be less than or equal to max_tokens", + ): + SamplingParams(min_tokens=15, max_tokens=10) + + +if __name__ == "__main__": + """ + Run tests locally for development. 
+ + Usage: + cd vllm/ + VLLM_USE_V1=1 python -m pytest tests/v1/e2e/test_min_tokens.py -v + """ + pytest.main([__file__, "-v"]) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 7b3f458312792..bd0fa6b80781a 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -144,6 +144,8 @@ def test_ngram_correctness( "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), True, marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + (("eagle", "eagle618/deepseek-v3-random", + "eagle618/eagle-deepseek-v3-random", 1), False), ], ids=[ # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 @@ -151,7 +153,8 @@ def test_ngram_correctness( "llama3_eagle", "llama3_eagle3", "llama4_eagle", - "llama4_eagle_mm" + "llama4_eagle_mm", + "deepseek_eagle" ]) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @@ -177,6 +180,7 @@ def test_eagle_correctness( ''' with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + m.setenv("VLLM_MLA_DISABLE", "1") m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) if (attn_backend == "TRITON_ATTN_VLLM_V1" diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 484640233f522..df04a14af70ce 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -212,6 +212,79 @@ async def test_abort( assert not engine.output_processor.has_unfinished_requests() +@pytest.mark.parametrize( + "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) +@pytest.mark.asyncio +async def test_multi_abort( + monkeypatch: pytest.MonkeyPatch, + output_kind: RequestOutputKind, +): + + with monkeypatch.context() as m, ExitStack() as after: + m.setenv("VLLM_USE_V1", "1") + + with set_default_torch_num_threads(1): + engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS) + after.callback(engine.shutdown) + + NUM_REQUESTS = 50 + NUM_EXPECTED_TOKENS 
= 100 + NUM_EXPECTED_TOKENS_LONG = 50000 + REQUEST_IDS_TO_ABORT = [5, 10, 15, 20, 25] + PARALLEL_SAMPLE_REQ_IDS = [5, 15, 30, 35] + + request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] + + # Create concurrent requests. + tasks: list[asyncio.Task] = [] + for idx, request_id in enumerate(request_ids): + max_tokens = (NUM_EXPECTED_TOKENS_LONG if + (idx + in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS) + n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1 + tasks.append( + asyncio.create_task( + generate(engine, request_id, TEXT_PROMPT, output_kind, + max_tokens, n))) + + # Let requests start + await asyncio.sleep(0.5) + + # Use multi-abort to abort multiple requests at once + abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT] + await engine.abort(abort_request_ids) + + # Wait for all tasks to complete + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Verify results + for idx, result in enumerate(results): + if idx in REQUEST_IDS_TO_ABORT: + # Aborted requests should return partial results + assert isinstance( + result, tuple + ), f"Request {idx} should have completed with partial results" + num_generated_tokens, request_id = result + # Should have generated some tokens before abort + assert num_generated_tokens > 0, ( + f"Aborted request " + f"{request_id} should have generated some tokens") + else: + # Non-aborted requests should complete normally + assert isinstance( + result, + tuple), f"Request {idx} should have completed successfully" + num_generated_tokens, request_id = result + n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1 + expected_tokens = NUM_EXPECTED_TOKENS * n + assert num_generated_tokens == expected_tokens, ( + f"{request_id} generated {num_generated_tokens} but " + f"expected {expected_tokens}") + + # Make sure all aborted requests were cleaned up + assert not engine.output_processor.has_unfinished_requests() + + @pytest.mark.parametrize("n", [1, 3]) @pytest.mark.parametrize( "engine_args,prompt", @@ -460,7 
+533,9 @@ async def test_abort_final_output( token_count = sum( len(output.outputs[0].token_ids) for output in outputs) assert token_count > 0 - assert len(final_output.outputs[0].token_ids) == 0 + # This would ordinarily be 0, but could end up > 0 if the + # final abort is coalesced with another chunk in the output queue. + assert len(final_output.outputs[0].token_ids) >= 0 else: # For FINAL_ONLY, we should only get the final output assert len(outputs) == 0 diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index c82285639aee4..37eb869fe69a3 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -121,8 +121,13 @@ async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict): # Dummy utility function to monkey-patch into engine core. -def echo(self, msg: str, err_msg: Optional[str] = None) -> str: +def echo(self, + msg: str, + err_msg: Optional[str] = None, + sleep: Optional[float] = None) -> str: print(f"echo util function called: {msg}, {err_msg}") + if sleep is not None: + time.sleep(sleep) if err_msg is not None: raise ValueError(err_msg) return msg @@ -289,6 +294,23 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): await core_client.call_utility_async("echo", None, "help!") assert str(e_info.value) == "Call to echo method failed: help!" + + # Test that cancelling the utility call doesn't destabilize the + # engine. + util_task = asyncio.create_task( + core_client.call_utility_async("echo", "testarg2", None, + 0.5)) # sleep for 0.5 sec + await asyncio.sleep(0.05) + cancelled = util_task.cancel() + assert cancelled + + # Ensure client is still functional. The engine runs utility + # methods in a single thread so this request won't be processed + # until the cancelled sleeping one is complete. 
+ result = await asyncio.wait_for(core_client.call_utility_async( + "echo", "testarg3"), + timeout=1.0) + assert result == "testarg3" finally: client.shutdown() diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 8bddfb0b48a50..cd82eb2ac4199 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -11,9 +11,11 @@ from typing import TYPE_CHECKING, Any import jsonschema import pytest import regex as re +import torch from pydantic import BaseModel from tests.reasoning.utils import run_reasoning_extraction +from vllm.distributed import cleanup_dist_env_and_memory from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput from vllm.platforms import current_platform @@ -39,8 +41,11 @@ EAGLE_SPEC_CONFIG = { PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None), + ("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto", + None), ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None), + ("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None), ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", @@ -127,13 +132,15 @@ def test_structured_output( temperature=1.0, max_tokens=4096, guided_decoding=GuidedDecodingParams(json=sample_json_schema)) - outputs = llm.generate(prompts=[ - (f"Give an example JSON for an employee profile that fits this " - f"schema. Make the response as short as possible. 
Schema: " - f"{sample_json_schema}") - ] * 2, - sampling_params=sampling_params, - use_tqdm=True) + + prompt = ("Give an example JSON for an employee profile that fits this " + "schema. Make the response as short as possible. Schema: " + f"{sample_json_schema}") + outputs = llm.generate( + [prompt] * 2, + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None @@ -144,7 +151,8 @@ def test_structured_output( generated_text = output.outputs[0].text assert generated_text is not None - assert "\n" not in generated_text + if guided_decoding_backend != 'lm-format-enforcer': + assert "\n" not in generated_text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") output_json = json.loads(generated_text) jsonschema.validate(instance=output_json, schema=sample_json_schema) @@ -191,20 +199,24 @@ def test_structured_output( with pytest.raises(ValueError, match="The provided JSON schema contains features " "not supported by xgrammar."): + + prompt = (f"Give an example JSON for an employee profile that " + f"fits this schema: {unsupported_json_schema}. " + f"Make the response as short as possible.") llm.generate( - prompts=[(f"Give an example JSON for an employee profile that " - f"fits this schema: {unsupported_json_schema}. " - f"Make the response as short as possible.")] * 2, + [prompt] * 2, sampling_params=sampling_params, - use_tqdm=True) + use_tqdm=True, + ) else: - outputs = llm.generate(prompts=( - "Give an example JSON object for a grade " - "that fits this schema: " - f"{unsupported_json_schema}. Make the response as short as " - "possible."), - sampling_params=sampling_params, - use_tqdm=True) + prompt = (f"Give an example JSON object for a grade that " + f"fits this schema: {unsupported_json_schema}. 
" + f"Make the response as short as possible.") + outputs = llm.generate( + prompt, + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None for output in outputs: assert output is not None @@ -217,7 +229,7 @@ def test_structured_output( parsed_json = json.loads(generated_text) assert isinstance(parsed_json, dict) - if guided_decoding_backend != "outlines": + if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]: # # Test 4: Generate SQL statement using EBNF grammar # @@ -227,10 +239,9 @@ def test_structured_output( max_tokens=1000, guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf)) outputs = llm.generate( - prompts=( - "Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1. Make the response as short as " - "possible."), + ("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short as " + "possible."), sampling_params=sampling_params, use_tqdm=True, ) @@ -261,10 +272,9 @@ def test_structured_output( max_tokens=1000, guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark)) outputs = llm.generate( - prompts=( - "Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1. Make the response as short as " - "possible."), + ("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short as " + "possible."), sampling_params=sampling_params, use_tqdm=True, ) @@ -301,7 +311,6 @@ def test_structured_output( guided_decoding=GuidedDecodingParams(grammar="not a grammar")) with pytest.raises(ValueError, match="Failed to convert the grammar "): llm.generate( - prompts= ("Generate a sql statement that selects col_1 from " "table_1 where it is equal to 1. 
Make the response as short " "as possible."), @@ -316,11 +325,11 @@ def test_structured_output( temperature=0.8, top_p=0.95, guided_decoding=GuidedDecodingParams(regex=sample_regex)) + + prompt = (f"Give an example IPv4 address with this regex: {sample_regex}. " + f"Make the response as short as possible.") outputs = llm.generate( - prompts=[ - (f"Give an example IPv4 address with this regex: {sample_regex}. " - f"Make the response as short as possible.") - ] * 2, + [prompt] * 2, sampling_params=sampling_params, use_tqdm=True, ) @@ -343,11 +352,13 @@ def test_structured_output( temperature=0.8, top_p=0.95, guided_decoding=GuidedDecodingParams(choice=sample_guided_choice)) + outputs = llm.generate( - prompts=("The best language for type-safe systems programming is " - "(Make the response as short as possible.) "), + ("The best language for type-safe systems programming is " + "(Make the response as short as possible.) "), sampling_params=sampling_params, - use_tqdm=True) + use_tqdm=True, + ) assert outputs is not None for output in outputs: assert output is not None @@ -367,12 +378,14 @@ def test_structured_output( temperature=1.0, max_tokens=1000, guided_decoding=GuidedDecodingParams(json=json_schema)) - outputs = llm.generate(prompts=( - "Generate a JSON with the brand, model and car_type of the most " - "iconic car from the 90's. Make the response as short as " - "possible."), - sampling_params=sampling_params, - use_tqdm=True) + + outputs = llm.generate( + ("Generate a JSON with the brand, model and car_type of the most " + "iconic car from the 90's. Make the response as short as " + "possible."), + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None @@ -411,10 +424,11 @@ def test_structured_output( guided_decoding=GuidedDecodingParams(json=json_schema)) outputs = llm.generate( - prompts=("Generate a description of a frog using 50 characters. 
" - "Make the response as short as possible."), + ("Generate a description of a frog using 50 characters. " + "Make the response as short as possible."), sampling_params=sampling_params, - use_tqdm=True) + use_tqdm=True, + ) assert outputs is not None @@ -429,7 +443,7 @@ def test_structured_output( output_json = json.loads(generated_text) jsonschema.validate(instance=output_json, schema=json_schema) - if guided_decoding_backend != "outlines": + if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]: # # Test 11: Generate structured output using structural_tag format # @@ -498,7 +512,7 @@ Make the response as short as possible. """ # Change this once other backends support structural_tag - outputs = llm.generate(prompts=prompt, + outputs = llm.generate(prompt, sampling_params=sampling_params, use_tqdm=True) assert outputs is not None @@ -639,15 +653,13 @@ def test_structured_output_auto_mode( f"{unsupported_json_schema}. Make the response as short as possible.") # This would fail with the default of "xgrammar", but in "auto" # we will handle fallback automatically. - outputs = llm.generate(prompts=prompts, + outputs = llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True) # Make sure `auto` backend handling doesn't mess up sampling_params # and that we can reuse it without error. 
outputs.extend( - llm.generate(prompts=prompts, - sampling_params=sampling_params, - use_tqdm=True)) + llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True)) assert outputs is not None for output in outputs: @@ -705,7 +717,7 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): max_tokens=256, guided_decoding=guided_params) - outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) + outputs = llm.generate(prompt, sampling_params=sampling_params) assert outputs is not None generated_text = outputs[0].outputs[0].text assert generated_text is not None @@ -721,3 +733,83 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): assert "a4" not in generated assert "a5" not in generated assert "a6" not in generated + + +@pytest.mark.parametrize("guided_decoding_backend", + ["guidance", "xgrammar", "outlines"]) +def test_structured_output_batched_with_non_guided_requests( + monkeypatch: pytest.MonkeyPatch, + sample_json_schema: dict[str, Any], + guided_decoding_backend: str, +): + monkeypatch.setenv("VLLM_USE_V1", "1") + + # Don't use eager execution on TPUs because we want to test for no + # recompilation at runtime + enforce_eager = bool(not current_platform.is_tpu()) + + llm = LLM( + model="meta-llama/Meta-Llama-3.1-8B-Instruct", + enforce_eager=enforce_eager, + max_model_len=1024, + guided_decoding_backend=guided_decoding_backend, + guided_decoding_disable_any_whitespace=(guided_decoding_backend + in {"xgrammar", "guidance"}), + ) + + guided_prompt = ( + "Give an example JSON for an employee profile that fits this " + "schema. Make the response as short as possible. 
Schema: " + f"{sample_json_schema}") + + non_guided_prompt = "The diameter of the Earth in kilometers is " + + prompts = [guided_prompt, non_guided_prompt] + sampling_params = [ + SamplingParams( + temperature=1.0, + max_tokens=400, + guided_decoding=GuidedDecodingParams(json=sample_json_schema)), + # No max tokens, temp=0 to assert on contents + SamplingParams( + seed=42, + temperature=0, + top_p=1.0, + ), + ] + + outputs = llm.generate(prompts=prompts, + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + + # Free memory as soon as possible as failed assertions + # will short circuit and not free up memory + del llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + for index, output in enumerate(outputs): + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}") + + if index == 0: + # First prompt is guided, expect valid JSON + assert "\n" not in generated_text + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, + schema=sample_json_schema) + else: + # Second prompt is not guided, expect valid output + # Cannot assert on exact output, but we can expect it to be factual + assert "12,742" in generated_text + + # non-guided requests should not return a valid JSON here + with pytest.raises(ValueError): + output_json = json.loads(generated_text) diff --git a/tests/v1/entrypoints/openai/responses/test_basic.py b/tests/v1/entrypoints/openai/responses/test_basic.py index 18c35152e7b20..7a0baa5767cba 100644 --- a/tests/v1/entrypoints/openai/responses/test_basic.py +++ b/tests/v1/entrypoints/openai/responses/test_basic.py @@ -73,3 +73,16 @@ async def test_chat_with_input_type(client: openai.AsyncOpenAI): ], ) print(response) assert response.status == "completed" + + +@pytest.mark.asyncio +async def 
test_logprobs(client: openai.AsyncOpenAI): + response = await client.responses.create( + include=["message.output_text.logprobs"], + input="What is 13 * 24?", + top_logprobs=5, + ) + print(response) + outputs = response.output + assert outputs[-1].content[-1].logprobs + assert len(outputs[-1].content[-1].logprobs[0].top_logprobs) == 5 diff --git a/tests/prefix_caching/__init__.py b/tests/v1/executor/__init__.py similarity index 100% rename from tests/prefix_caching/__init__.py rename to tests/v1/executor/__init__.py diff --git a/tests/v1/executor/test_executor.py b/tests/v1/executor/test_executor.py new file mode 100644 index 0000000000000..bdd5155c1481d --- /dev/null +++ b/tests/v1/executor/test_executor.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import os +from typing import Any, Callable, Optional, Union + +import pytest + +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.sampling_params import SamplingParams +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.engine.llm_engine import LLMEngine +from vllm.v1.executor.multiproc_executor import MultiprocExecutor + + +class Mock: + ... + + +class CustomMultiprocExecutor(MultiprocExecutor): + + def collective_rpc(self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + non_block: bool = False, + unique_reply_rank: Optional[int] = None) -> list[Any]: + # Drop marker to show that this was ran + with open(".marker", "w"): + ... 
+ return super().collective_rpc(method, timeout, args, kwargs) + + +CustomMultiprocExecutorAsync = CustomMultiprocExecutor +MODEL = "Qwen/Qwen3-0.6B" + + +def test_custom_executor_type_checking(): + with pytest.raises(ValueError): + engine_args = EngineArgs( + model=MODEL, + gpu_memory_utilization=0.2, + max_model_len=8192, + distributed_executor_backend=Mock, + ) + LLMEngine.from_engine_args(engine_args) + with pytest.raises(ValueError): + engine_args = AsyncEngineArgs(model=MODEL, + gpu_memory_utilization=0.2, + max_model_len=8192, + distributed_executor_backend=Mock) + AsyncLLM.from_engine_args(engine_args) + + +@pytest.mark.parametrize("distributed_executor_backend", [ + CustomMultiprocExecutor, + "tests.v1.executor.test_executor.CustomMultiprocExecutor" +]) +def test_custom_executor(distributed_executor_backend, tmp_path): + cwd = os.path.abspath(".") + os.chdir(tmp_path) + try: + assert not os.path.exists(".marker") + + engine_args = EngineArgs( + model=MODEL, + gpu_memory_utilization=0.2, + max_model_len=8192, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, # reduce test time + ) + engine = LLMEngine.from_engine_args(engine_args) + sampling_params = SamplingParams(max_tokens=1) + + engine.add_request("0", "foo", sampling_params) + engine.step() + + assert os.path.exists(".marker") + finally: + os.chdir(cwd) + + +@pytest.mark.parametrize("distributed_executor_backend", [ + CustomMultiprocExecutorAsync, + "tests.v1.executor.test_executor.CustomMultiprocExecutorAsync" +]) +def test_custom_executor_async(distributed_executor_backend, tmp_path): + cwd = os.path.abspath(".") + os.chdir(tmp_path) + try: + assert not os.path.exists(".marker") + + engine_args = AsyncEngineArgs( + model=MODEL, + gpu_memory_utilization=0.2, + max_model_len=8192, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, # reduce test time + ) + engine = AsyncLLM.from_engine_args(engine_args) + sampling_params = 
SamplingParams(max_tokens=1) + + async def t(): + stream = engine.generate(request_id="0", + prompt="foo", + sampling_params=sampling_params) + async for x in stream: + ... + + asyncio.run(t()) + + assert os.path.exists(".marker") + finally: + os.chdir(cwd) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index b185936ab025f..040b44dc5d2ca 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -14,6 +14,7 @@ from unittest.mock import patch import pytest import ray +import torch from vllm import LLM from vllm.config import KVTransferConfig @@ -22,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( NixlConnectorWorker) from vllm.forward_context import ForwardContext from vllm.sampling_params import SamplingParams +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from .utils import create_request, create_scheduler, create_vllm_config @@ -98,7 +100,6 @@ class FakeNixlWrapper: def set_cycles_before_xfer_done(self, cycles: int): """Set the number of cycles before a transfer is considered done.""" - self._cycles_before_xfer_done = cycles @contextlib.contextmanager @@ -147,6 +148,7 @@ def test_basic_interface(): NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) request = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True) request_id = request.request_id @@ -186,6 +188,7 @@ def test_prompt_less_than_block_size(): # Request will have 1 partial remote block. request = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True, num_remote_blocks=1) @@ -560,3 +563,86 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int): sampling_params) # Request-0 times out and is cleared! 
assert '0' not in req_to_blocks + + +def test_register_kv_caches(dist_init): + """ + Test that register_kv_caches() properly calls nixl_wrapper methods with + correct data. + + This test verifies: + 1. nixl_wrapper.get_reg_descs() is called with caches_data containing + tensor metadata + 2. nixl_wrapper.get_xfer_descs() is called with blocks_data containing + block layout info + """ + + vllm_config = create_vllm_config() + + # Create test kv cache tensors using proper backend shape + kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(num_blocks=2, + block_size=16, + num_kv_heads=4, + head_size=64) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } + + # Store tensor info for validation + expected_tensor_size = shared_tensor[0].element_size( + ) * shared_tensor[0].numel() + expected_base_addrs = [ + shared_tensor[0].data_ptr(), shared_tensor[1].data_ptr(), + unique_tensor[0].data_ptr(), unique_tensor[1].data_ptr() + ] + + with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper") as mock_nixl_wrapper, \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"): # noqa: E501 + + # Create connector + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0) + + # Get the mock instance + mock_wrapper_instance = mock_nixl_wrapper.return_value + connector.connector_worker.nixl_wrapper = mock_wrapper_instance + + # Execute register_kv_caches + connector.register_kv_caches(kv_caches) + + # Verify get_reg_descs was called with caches_data + assert mock_wrapper_instance.get_reg_descs.called + caches_data, _ = 
mock_wrapper_instance.get_reg_descs.call_args[0] + assert len(caches_data) == 4 + + for i, cache_entry in enumerate(caches_data): + base_addr, size, _tp_rank, _ = cache_entry + assert size == expected_tensor_size, \ + f"Entry {i}: Expected tensor size {expected_tensor_size}, " \ + f"got {size}" + assert base_addr == expected_base_addrs[i], \ + f"Entry {i}: Expected base address {expected_base_addrs[i]}, " \ + f"got {base_addr}" + + # Verify get_xfer_descs was called with blocks_data + assert mock_wrapper_instance.get_xfer_descs.called + blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0] + + # Validate blocks_data structure and size + expected_blocks_count = 8 + assert len(blocks_data) == expected_blocks_count, \ + f"Expected {expected_blocks_count} blocks, " \ + f"got {len(blocks_data)}" + + expected_block_len = expected_tensor_size // 2 + for i, block_entry in enumerate(blocks_data): + block_start_addr, block_len, tp_rank = block_entry + assert block_len == expected_block_len, \ + f"Block entry {i}: Expected block len {expected_block_len}, " \ + f"got {block_len}" diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index 2f8228864e7b4..d8c56ac42f718 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -21,6 +21,7 @@ def test_basic_lifecycle(): NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) request = create_request(request_id=1, + block_size=BLOCK_SIZE, max_tokens=1, num_tokens=NUM_TOKENS, do_remote_decode=True) @@ -103,8 +104,10 @@ def test_short_prompt_lifecycle(): scheduler = create_scheduler(vllm_config) # Not enough tokens for full block. 
- NUM_TOKENS = vllm_config.cache_config.block_size // 2 + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_TOKENS = BLOCK_SIZE // 2 request = create_request(request_id=1, + block_size=BLOCK_SIZE, max_tokens=1, num_tokens=NUM_TOKENS, do_remote_decode=True) @@ -148,7 +151,9 @@ def test_prefix_cache_lifecycle(): NUM_EXTERNAL_FULL_BLOCKS = 3 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS) + request_normal = create_request(request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS) scheduler.add_request(request_normal) scheduler_output = scheduler.schedule() @@ -166,6 +171,7 @@ def test_prefix_cache_lifecycle(): NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) request_remote = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_decode=True) diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index 87f7490698a31..21fec5344255c 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -23,6 +23,7 @@ def test_basic_lifecycle(): scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) request = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True) @@ -133,14 +134,17 @@ def test_interleaved_lifecycle(): NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) request_remote = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True) request_local_a = create_request( request_id=2, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, ) request_local_b = create_request( request_id=3, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, ) @@ -236,6 +240,7 @@ def test_no_spurious_prefix_caching(): # Both of these requests have prompts like [1,1,1,1,1, ...] 
request_remote = create_request( request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True, use_all_1s_for_prompt_tokens=True, @@ -243,6 +248,7 @@ def test_no_spurious_prefix_caching(): request_local = create_request( request_id=2, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=False, use_all_1s_for_prompt_tokens=True, @@ -292,6 +298,7 @@ def test_full_block_prompt(): NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS) request = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True) @@ -364,8 +371,11 @@ def test_cannot_schedule_after_recv(): NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) - request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL) + request_normal = create_request(request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS_LOCAL) request_remote = create_request(request_id=2, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_REMOTE, do_remote_prefill=True) @@ -456,8 +466,11 @@ def test_cannot_recv(): NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5)) - request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL) + request_normal = create_request(request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS_LOCAL) request_remote = create_request(request_id=2, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_REMOTE, do_remote_prefill=True) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 60847c48585c6..a47f583b329e2 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import tempfile from collections import defaultdict -from typing import Any, Optional +from typing import Any, Callable, Optional import torch @@ -14,6 +14,8 @@ from 
vllm.distributed.kv_transfer.kv_connector.factory import ( from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa SharedStorageConnector) from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, + init_none_hash) from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec) @@ -40,7 +42,6 @@ def assert_scheduler_empty(scheduler: Scheduler): # KVCache Manager. assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. req_to_blocks) == 0 - assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. num_cached_block) == 0 num_free_blocks = ( @@ -115,16 +116,23 @@ def create_scheduler( ) -def create_request( - request_id: int, - num_tokens: int = 10, - max_tokens: int = 16, - do_remote_decode: bool = False, - do_remote_prefill: bool = False, - use_all_1s_for_prompt_tokens: bool = False, - num_remote_blocks: int = 3, -) -> Request: +_none_hash_initialized = False + + +def create_request(request_id: int, + num_tokens: int = 10, + max_tokens: int = 16, + do_remote_decode: bool = False, + do_remote_prefill: bool = False, + use_all_1s_for_prompt_tokens: bool = False, + num_remote_blocks: int = 3, + block_size: int = 16, + hash_fn: Callable = hash) -> Request: """Make dummy request for testing.""" + global _none_hash_initialized + if not _none_hash_initialized: + init_none_hash(hash) + _none_hash_initialized = True kv_transfer_params: Optional[dict[str, Any]] = None @@ -158,6 +166,7 @@ def create_request( multi_modal_placeholders=None, multi_modal_hashes=None, eos_token_id=EOS_TOKEN_ID, + block_hasher=get_request_block_hasher(block_size, hash_fn), ) req.kv_transfer_params = kv_transfer_params return req @@ -191,7 +200,6 @@ def create_model_runner_output( req_ids=req_ids, 
req_id_to_index=req_id_to_index, sampled_token_ids=sampled_token_ids, - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=None, diff --git a/tests/v1/logits_processors/__init__.py b/tests/v1/logits_processors/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/v1/sample/test_logits_processors.py b/tests/v1/logits_processors/test_correctness.py similarity index 97% rename from tests/v1/sample/test_logits_processors.py rename to tests/v1/logits_processors/test_correctness.py index 84ee3b0392b40..43caef79b02f7 100644 --- a/tests/v1/sample/test_logits_processors.py +++ b/tests/v1/logits_processors/test_correctness.py @@ -9,11 +9,13 @@ import numpy as np import pytest import torch +from tests.utils import create_new_process_for_each_test from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits, create_penalty_tensor, create_prompt_tokens_tensor, fake_apply_logitsprocs, fake_update_logitsprocs_state) +from vllm.config import VllmConfig from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available @@ -24,7 +26,7 @@ from vllm.v1.sample.logits_processor import (BatchUpdate, BatchUpdateBuilder, MinPLogitsProcessor, MinTokensLogitsProcessor, MoveDirectionality, - init_builtin_logitsprocs) + build_logitsprocs) # yapf: enable from vllm.v1.sample.metadata import SamplingMetadata @@ -53,6 +55,7 @@ class LogitsProcsRequestParams: workload_index: int logitproc_type: LogitprocType # Logitproc enabled, specified by str id out_tokens: list[int] # Output tokens required for min tokens test + prompt_tokens: list[int] # Dummy prompt tokens placeholder params: SamplingParams # Settings customized for logitproc def __init__(self, workload_index: int, logitproc_type: LogitprocType): @@ -63,6 +66,7 @@ class LogitsProcsRequestParams: # don't matter *for these tests* so use 0 as a dummy value self.out_tokens = ([0] * 
(MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2))) + self.prompt_tokens = [] self.params = _sampling_params_from_logitproc(logitproc_type) def __str__(self): @@ -88,11 +92,12 @@ def _generate_fake_sampling_metadata( vocab_size, size=np.random.randint( 1, MAX_NUM_PROMPT_TOKENS)).tolist()) - logitsprocs = init_builtin_logitsprocs( - pin_memory_available=PIN_MEMORY_AVAILABLE, - max_num_reqs=MAX_NUM_REQS + 1, - device=device) - + logitsprocs = build_logitsprocs( + vllm_config=VllmConfig(), + device=device, + is_pin_memory=PIN_MEMORY_AVAILABLE, + is_pooling_model=False, + ) fake_sampling_metadata = SamplingMetadata( temperature=torch.full((batch_size, ), 0.0), all_greedy=True, @@ -462,7 +467,8 @@ def _generate_fake_step_update( # Replace as many removed requests as possible with added requests add_remove_idx = batch_update_builder.pop_removed() batch_update_builder.added.append( - (add_remove_idx, add_req_params.params, add_req_params.out_tokens)) + (add_remove_idx, add_req_params.params, + add_req_params.prompt_tokens, add_req_params.out_tokens)) persistent_batch[add_remove_idx] = add_req_params # Append remaining added requests to end of batch @@ -470,7 +476,8 @@ def _generate_fake_step_update( num_step_add_replace):(wdx + num_step_add)] batch_update_builder.added.extend([ - (adx + batch_size, add_req_params.params, add_req_params.out_tokens) + (adx + batch_size, add_req_params.params, add_req_params.prompt_tokens, + add_req_params.out_tokens) for adx, add_req_params in enumerate(add_reqs_append) ]) persistent_batch.extend(add_reqs_append) @@ -561,6 +568,7 @@ def _assert_valid( step_idx=step_idx) +@create_new_process_for_each_test() @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC]) @pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases()) diff --git a/tests/v1/logits_processors/test_custom_offline.py b/tests/v1/logits_processors/test_custom_offline.py new file mode 100644 index 
0000000000000..a7fde1990f7ed --- /dev/null +++ b/tests/v1/logits_processors/test_custom_offline.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random +import sys +from typing import Union + +import pytest + +from tests.utils import create_new_process_for_each_test +# yapf: disable +from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG, + DUMMY_LOGITPROC_FQCN, + DUMMY_LOGITPROC_MODULE, + MAX_TOKENS, MODEL_NAME, + POOLING_MODEL_NAME, TEMP_GREEDY, + CustomLogitprocSource, + DummyLogitsProcessor, + dummy_module) +from tests.v1.logits_processors.utils import entry_points as fake_entry_points +from tests.v1.logits_processors.utils import prompts +# yapf: enable +from vllm import LLM, SamplingParams +from vllm.v1.sample.logits_processor import (STR_POOLING_REJECTS_LOGITSPROCS, + LogitsProcessor) + +# Create a mixture of requests which do and don't utilize the dummy logitproc +sampling_params_list = [ + SamplingParams(temperature=TEMP_GREEDY, + max_tokens=MAX_TOKENS, + extra_args={DUMMY_LOGITPROC_ARG: 128}), + SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS), + SamplingParams(temperature=TEMP_GREEDY, + max_tokens=MAX_TOKENS, + extra_args={DUMMY_LOGITPROC_ARG: 67}), + SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS), +] + + +def _run_test(kwargs: dict, logitproc_loaded: bool) -> None: + """Compare `LLM` instance initialized with specified `kwargs` against + reference `LLM` instance. + + Two scenarios: + 1. Server has loaded dummy logitproc; test that requests which specify + dummy logitproc arg value behave as if logitproc is operating (output + token value should repeat), while requests that don't specify dummy + logitproc arg value should match reference `LLM` output. + 2. Server has *not* loaded dummy logitproc; test that all requests + behave as if logitproc is *not* operating (output matches reference + `LLM` output.) 
+ + Args: + kwargs: `LLM` constructor kwargs + logitproc_loaded: server has loaded dummy logitproc if True + """ + + # Create a vLLM instance and load custom logitproc + llm_logitproc = LLM( + model=MODEL_NAME, + gpu_memory_utilization=0.1, + **kwargs, + ) + + # Create a reference vLLM instance without custom logitproc + llm_ref = LLM(model=MODEL_NAME, gpu_memory_utilization=0.1) + + # Run inference with logitproc loaded + outputs_logitproc = llm_logitproc.generate(prompts, sampling_params_list) + + # Reference run + outputs_ref = llm_ref.generate(prompts, sampling_params_list) + + # Validate outputs + for bdx, (out_lp, out_ref, params) in enumerate( + zip(outputs_logitproc, outputs_ref, sampling_params_list)): + lp_toks = out_lp.outputs[0].token_ids + if logitproc_loaded and params.extra_args: + # This request exercises custom logitproc; validate that logitproc + # forces `target_token` to be decoded in each step + target_token = params.extra_args[DUMMY_LOGITPROC_ARG] + if not all(x == target_token for x in lp_toks): + raise AssertionError( + f"Request {bdx} generated {lp_toks}, shoud all be " + f"{target_token}") + else: + # This request does not exercise custom logitproc (or custom + # logitproc is not enabled on this server); validate against + # reference result + ref_toks = out_ref.outputs[0].token_ids + if lp_toks != ref_toks: + raise AssertionError( + f"Request {bdx} generated {lp_toks}, should match " + f"{ref_toks}") + + +@create_new_process_for_each_test() +@pytest.mark.parametrize("logitproc_source", list(CustomLogitprocSource)) +def test_custom_logitsprocs(monkeypatch, + logitproc_source: CustomLogitprocSource): + """Test offline Python interface for passing custom logitsprocs + + Construct an `LLM` instance which loads a custom logitproc that has a + well-defined behavior (mask out all tokens except one `target_token`) + + Construct a reference `LLM` instance with no custom logitproc + + Pass in a batch of requests, 50% of which pass a `target_token` 
value + in through `SamplingParams.extra_args`, 50% of which do not. + + Validate that + * Requests which do not activate the custom logitproc, yield the same + results for both `LLM` instances + * Requests which activate the custom logitproc, only output `target_token` + + Test four scenarios, corresponding to `logitproc_source` value + * No logitsprocs loaded - test that generated tokens match reference `LLM` + instance output + * Logitproc passed in via {entrypoint, class object, fully-qualified class + name (FQCN)} - test that dummy logitproc is utilized correctly when + provided via any of these three possible sources + + Args: + monkeypatch: for setting env vars + logitproc_source: what source (entrypoint, fully-qualified class name + (FQCN), class object, or None) the user pulls the + logitproc from + """ + + # Test that logitproc info is passed to workers + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1") + random.seed(40) + + # Choose LLM args based on logitproc source + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_NONE: + # Scenario: the server does not load any custom logitproc + # Every other scenario is a different way of loading a custom logitproc + _run_test({}, logitproc_loaded=False) + return + + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT: + # Scenario: vLLM loads a logitproc from a preconfigured entrypoint + # To that end, mock a dummy logitproc entrypoint + import importlib.metadata + importlib.metadata.entry_points = fake_entry_points # type: ignore + + # fork is required for workers to see entrypoint patch + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork") + _run_test({}, logitproc_loaded=True) + return + + kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {} + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN: + # Scenario: load logitproc based on fully-qualified class name (FQCN) + # Inject dummy module which defines logitproc + 
sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module + kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN] + elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS: + # Scenario: load logitproc from provided class object + kwargs["logits_processors"] = [DummyLogitsProcessor] + + _run_test(kwargs, logitproc_loaded=True) + + +@create_new_process_for_each_test() +@pytest.mark.parametrize("logitproc_source", [ + CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT, + CustomLogitprocSource.LOGITPROC_SOURCE_FQCN, + CustomLogitprocSource.LOGITPROC_SOURCE_CLASS, +]) +def test_pooling_rejects_custom_logitsprocs( + monkeypatch, logitproc_source: CustomLogitprocSource): + """Validate that vLLM engine initialization properly rejects custom + logitsprocs when the model is a pooling model. + + Use `LLM` entrypoint. We expect `LLM` initialization to fail before the + logitproc is actually loaded. + + Scenario 1: + * Mock a logitproc entrypoint + * Validate that `LLM` does not load the logitproc + + Scenario 2: + * Pass custom logitproc to `LLM` constructor + * Scenario 2a: via FQCN + * Scenario 2b: via class object + * Validate that initialization fails with appropriate exception + + Args: + monkeypatch: used to set environment variables + logitproc_source: what source (entrypoint, fully-qualified class name + (FQCN), or class object) the user pulls the + logitproc from + """ + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + random.seed(40) + + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT: + # Scenario: vLLM loads a pooling model and ignores a logitproc that is + # available at a preconfigured entrypoint + + # Patch in dummy logitproc entrypoint + import importlib.metadata + importlib.metadata.entry_points = fake_entry_points # type: ignore + + # fork is required for entrypoint patch to be visible to workers, + # although they should ignore the entrypoint patch anyway + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork") + 
+ llm = LLM( + runner="pooling", + model=POOLING_MODEL_NAME, + gpu_memory_utilization=0.1, + ) + # Require that no logitsprocs have been loaded + assert sum([ + 1 for _ in llm.llm_engine.model_executor.driver_worker.worker. + model_runner.input_batch.logitsprocs.all + ]) == 0 + return + + kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {} + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN: + # Scenario: load logitproc based on fully-qualified class name (FQCN) + kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN] + elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS: + # Scenario: load logitproc from provided class object + kwargs["logits_processors"] = [DummyLogitsProcessor] + + with pytest.raises(ValueError, match=STR_POOLING_REJECTS_LOGITSPROCS): + # Require that loading a pooling model alongside the logitproc raises + # the appropriate exception. + LLM( + runner="pooling", + model=POOLING_MODEL_NAME, + gpu_memory_utilization=0.1, + **kwargs, + ) diff --git a/tests/v1/logits_processors/test_custom_online.py b/tests/v1/logits_processors/test_custom_online.py new file mode 100644 index 0000000000000..a01a479e5b248 --- /dev/null +++ b/tests/v1/logits_processors/test_custom_online.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import random +import sys +from typing import Any, Optional + +import openai +import pytest +import pytest_asyncio + +from tests.utils import (RemoteOpenAIServerCustom, + create_new_process_for_each_test) +# yapf: disable +from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG, + DUMMY_LOGITPROC_FQCN, + DUMMY_LOGITPROC_MODULE, + MAX_TOKENS, MODEL_NAME, + TEMP_GREEDY, dummy_module) +from tests.v1.logits_processors.utils import entry_points as fake_entry_points +from tests.v1.logits_processors.utils import prompts + +# yapf: enable + + +def _server_with_logitproc_entrypoint( + env_dict: 
Optional[dict[str, str]], + model: str, + vllm_serve_args: list[str], +) -> None: + """Start vLLM server, inject dummy logitproc entrypoint""" + + # Patch `entry_points` to inject logitproc entrypoint + import importlib.metadata + importlib.metadata.entry_points = fake_entry_points # type: ignore + from vllm.entrypoints.cli import main + + # fork is required for workers to see entrypoint patch + os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork" + if env_dict is not None: + os.environ.update(env_dict) + + # Emulate `vllm serve <model> <CLI args>` + sys.argv = ["vllm", "serve", model] + vllm_serve_args + main.main() + + +def _server_with_logitproc_module( + env_dict: Optional[dict[str, str]], + model: str, + vllm_serve_args: list[str], +) -> None: + """Start vLLM server, inject module with dummy logitproc""" + + # Patch `modules` to inject dummy logitproc module + from vllm.entrypoints.cli import main + sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module + + # fork is required for workers to see entrypoint patch + os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork" + if env_dict is not None: + os.environ.update(env_dict) + + # Emulate `vllm serve <model> <CLI args>` + sys.argv = ["vllm", "serve", model] + vllm_serve_args + main.main() + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + ] + + +@pytest.fixture(scope="function", + params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]]) +def server(default_server_args, request, monkeypatch): + """Consider two server configurations: + (1) --logits-processors cli arg specifies dummy logits processor via fully- + qualified class name (FQCN); patch in a dummy logits processor module + (2) No --logits-processors cli arg; patch in a dummy logits processor + entrypoint + """ + + # Test that logitproc info is passed to workers + 
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1") + + if request.param: + # Launch server, append FQCN argument, inject dummy logitproc module + args = default_server_args + request.param + _server_fxn = _server_with_logitproc_module + else: + # Launch server, inject dummy logitproc entrypoint + args = default_server_args + _server_fxn = _server_with_logitproc_entrypoint + + with RemoteOpenAIServerCustom(MODEL_NAME, args, + _server_fxn) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +# General request argument values for these tests +api_keyword_args = { + # Greedy sampling ensures that requests which receive the `target_token` + # arg will decode it in every step + "temperature": TEMP_GREEDY, + # Since EOS will never be decoded (unless `target_token` is EOS) + "max_tokens": MAX_TOKENS, + # Return decoded token logprobs (as a way of getting token id) + "logprobs": 0, +} + + +@create_new_process_for_each_test() +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str): + """Test custom logitsprocs when starting OpenAI server from CLI + + Launch vLLM OpenAI-compatible server, configured to load a custom logitproc + that has a well-defined behavior (mask out all tokens except one + `target_token`). + + Pass in requests, 50% of which pass a `target_token` value + in through `extra_body["vllm_xargs"]`, 50% of which do not. 
+ + Validate that requests which activate the custom logitproc, repeat the same + token + """ + + use_dummy_logitproc = True + for prompt in prompts: + # Build request arguments + request_keyword_args: dict[str, Any] = { + **api_keyword_args, + } + if use_dummy_logitproc: + # 50% of requests pass target_token custom arg + target_token = random.choice([128, 67]) + # For requests which activate the dummy logitproc, choose one of + # two `target_token` values which are known not to be EOS tokens + request_keyword_args["extra_body"] = { + "vllm_xargs": { + DUMMY_LOGITPROC_ARG: target_token + } + } + batch = await client.completions.create( + model=model_name, + prompt=prompt, + **request_keyword_args, + ) + + if use_dummy_logitproc: + # Only for requests which activate dummy logitproc - validate that + # output token is repeated + choices: openai.types.CompletionChoice = batch.choices + toks = choices[0].logprobs.tokens + if not all([x == toks[0] for x in toks]): + raise AssertionError( + f"Generated {toks} should all be {toks[0]}") + + # Alternate whether to activate dummy logitproc for each request + use_dummy_logitproc = not use_dummy_logitproc diff --git a/tests/v1/logits_processors/utils.py b/tests/v1/logits_processors/utils.py new file mode 100644 index 0000000000000..c0bfc1a18feca --- /dev/null +++ b/tests/v1/logits_processors/utils.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import types +from enum import Enum, auto +from typing import Optional + +import torch + +from vllm.config import VllmConfig +from vllm.sampling_params import SamplingParams +from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate, + LogitsProcessor, + MoveDirectionality) + +MODEL_NAME = "facebook/opt-125m" +POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5" +DUMMY_LOGITPROC_ARG = "target_token" +TEMP_GREEDY = 0.0 +MAX_TOKENS = 20 +DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc" 
+DUMMY_LOGITPROC_MODULE = "DummyModule" +DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor" + + +class CustomLogitprocSource(Enum): + """How to source a logitproc for testing purposes""" + LOGITPROC_SOURCE_NONE = auto() # No custom logitproc + LOGITPROC_SOURCE_ENTRYPOINT = auto() # Via entrypoint + LOGITPROC_SOURCE_FQCN = auto() # Via fully-qualified class name (FQCN) + LOGITPROC_SOURCE_CLASS = auto() # Via provided class object + + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + + +class DummyLogitsProcessor(LogitsProcessor): + """Fake logit processor to support unit testing and examples""" + + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): + self.req_info: dict[int, SamplingParams] = {} + + def is_argmax_invariant(self) -> bool: + """Never impacts greedy sampling""" + return False + + def update_state(self, batch_update: Optional[BatchUpdate]): + if not batch_update: + return + + # Process added requests. + for index, params, _, _ in batch_update.added: + assert params is not None + if params.extra_args and (target_token := + params.extra_args.get("target_token")): + self.req_info[index] = target_token + + if self.req_info: + # Process removed requests. 
+ for index in batch_update.removed: + self.req_info.pop(index, None) + + # Process moved requests, unidirectional move (a->b) and swap + # (a<->b) + for adx, bdx, direct in batch_update.moved: + a_val = self.req_info.pop(adx, None) + b_val = self.req_info.pop(bdx, None) + if a_val is not None: + self.req_info[bdx] = a_val + if direct == MoveDirectionality.SWAP and b_val is not None: + self.req_info[adx] = b_val + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if not self.req_info: + return logits + + # Save target values before modification + rows_list = list(self.req_info.keys()) + cols = torch.tensor([self.req_info[i] for i in rows_list], + dtype=torch.long, + device=logits.device) + rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device) + values_to_keep = logits[rows, cols].clone() + + # Mask all but target tokens + logits[rows] = float('-inf') + logits[rows, cols] = values_to_keep + + return logits + + +"""Dummy module with dummy logitproc class""" +dummy_module = types.ModuleType(DUMMY_LOGITPROC_MODULE) +dummy_module.DummyLogitsProcessor = DummyLogitsProcessor # type: ignore + + +class EntryPoint: + """Dummy entrypoint class for logitsprocs testing""" + + def __init__(self): + self.name = DUMMY_LOGITPROC_ENTRYPOINT + self.value = DUMMY_LOGITPROC_FQCN + + def load(self): + return DummyLogitsProcessor + + +class EntryPoints(list): + """Dummy EntryPoints class for logitsprocs testing""" + + def __init__(self, group: str): + # Emulate list-like functionality + eps = [EntryPoint()] if group == LOGITSPROCS_GROUP else [] + super().__init__(eps) + # Extra attributes + self.names = [ep.name for ep in eps] + + +"""Fake version of importlib.metadata.entry_points""" +entry_points = lambda group: EntryPoints(group) diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 8bd142e87b06e..e835c029634ce 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -456,9 +456,7 @@ def 
test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch): assert len(logprob) == vocab_size -@pytest.mark.parametrize( - "logprobs_mode", - ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"]) +@pytest.mark.parametrize("logprobs_mode", list(LogprobsMode)) def test_logprobs_mode(logprobs_mode: LogprobsMode, monkeypatch: pytest.MonkeyPatch): """Test with LLM engine with different logprobs_mode. @@ -487,12 +485,14 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode, for logprobs in output.logprobs: for token_id in logprobs: logprob = logprobs[token_id] - if "logprobs" in logprobs_mode: + if logprobs_mode in (LogprobsMode.RAW_LOGPROBS, + LogprobsMode.PROCESSED_LOGPROBS): assert logprob.logprob <= 0 if logprob.logprob > 0: positive_values = positive_values + 1 total_token_with_logprobs = total_token_with_logprobs + 1 assert total_token_with_logprobs >= len(results[0].outputs) - if "logits" in logprobs_mode: + if logprobs_mode in (LogprobsMode.RAW_LOGITS, + LogprobsMode.PROCESSED_LOGITS): assert positive_values > 0 del llm diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 3a4d48afc9d77..4e912f98f376f 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -7,7 +7,7 @@ import torch import torch.nn.functional as F from vllm.platforms import current_platform -from vllm.v1.sample.logits_processor import LogitsProcessorManager +from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, RejectionSampler) @@ -69,7 +69,7 @@ def create_sampling_metadata( output_token_ids=[], allowed_token_ids_mask=None, bad_words_token_ids={}, - logitsprocs=LogitsProcessorManager(), + logitsprocs=LogitsProcessors(), ) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index 31c6c881d7b83..53215f88bb27e 100644 
--- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -9,7 +9,7 @@ import torch from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available, make_tensor_with_pad -from vllm.v1.sample.logits_processor import LogitsProcessorManager +from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler @@ -173,7 +173,7 @@ def _create_default_sampling_metadata( no_penalties=True, allowed_token_ids_mask=None, bad_words_token_ids={}, - logitsprocs=LogitsProcessorManager(), + logitsprocs=LogitsProcessors(), ) return fake_sampling_metadata diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index 456ce712d36e4..6317817408661 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -50,6 +50,7 @@ def forward_attention( dtype=torch.int32, ) context_lens = seq_lens - query_lens + max_seq_len = int(seq_lens.max()) max_query_len = q_len num_actual_tokens = query_start_loc[-1] @@ -81,6 +82,7 @@ def forward_attention( num_reqs=batch_size, num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, + max_seq_len=max_seq_len, block_table_tensor=block_table, slot_mapping=slot_mapping, ) diff --git a/tests/v1/test_kv_sharing.py b/tests/v1/test_kv_sharing.py new file mode 100644 index 0000000000000..6b01b7d3e1d6c --- /dev/null +++ b/tests/v1/test_kv_sharing.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import Mock + +import torch + +from vllm.v1.attention.backends.flash_attn import ( + FlashAttentionBackend, FlashAttentionMetadataBuilder) +from vllm.v1.attention.backends.flex_attention import ( + FlexAttentionBackend, FlexAttentionMetadataBuilder) +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec +from 
vllm.v1.worker.utils import (AttentionGroup, + initialize_kv_cache_for_kv_sharing) + + +def new_kv_cache_spec(): + return FullAttentionSpec(16, 1, 1, torch.float32, False) + + +def test_initialize_kv_cache_for_kv_sharing_different_attn_groups(): + """ + Test initializing KV cache sharing with different attention groups. + Layers in the same KV cache group might be placed in different attn groups + if they have different attention backends. + """ + shared_kv_cache_layers = { + "model.layers.2": "model.layers.0", + "model.layers.3": "model.layers.1", + } + + # Layers 0 and 1 both belong in KV cache group 0 + # However, if they have have different attention backends, they will be + # placed in different attention groups for KV cache group 0 + kv_cache_groups = [ + KVCacheGroupSpec(["model.layers.0", "model.layers.1"], + new_kv_cache_spec()), + ] + + attn_groups = [ + # KV cache group 0 has two attention groups + [ + AttentionGroup( + backend=FlashAttentionBackend, + metadata_builder=Mock(spec=FlashAttentionMetadataBuilder), + layer_names=["model.layers.0"], + ), + AttentionGroup( + backend=FlexAttentionBackend, + metadata_builder=Mock(spec=FlexAttentionMetadataBuilder), + layer_names=["model.layers.1"], + ), + ], + ] + + # Only layers 0 and 1 will have KV caches allocated + kv_caches = { + "model.layers.0": torch.zeros(1, 2, 3), + "model.layers.1": torch.ones(1, 2, 3), + } + + initialize_kv_cache_for_kv_sharing( + shared_kv_cache_layers=shared_kv_cache_layers, + kv_cache_groups=kv_cache_groups, + kv_caches=kv_caches, + attn_groups=attn_groups, + ) + + # Check that the KV caches were shared correctly + assert kv_caches["model.layers.2"].data_ptr( + ) == kv_caches["model.layers.0"].data_ptr() + assert kv_caches["model.layers.3"].data_ptr( + ) == kv_caches["model.layers.1"].data_ptr() + + # Check that the layers were added to the correct KV cache group + assert len(kv_cache_groups) == 1 + assert kv_cache_groups[0].layer_names == [ + "model.layers.0", "model.layers.1", 
"model.layers.2", "model.layers.3" + ] + + # Check that the layers were added to the attention groups + assert len(attn_groups) == 1 and len(attn_groups[0]) == 2 + assert attn_groups[0][0].layer_names == [ + "model.layers.0", "model.layers.2" + ] + assert attn_groups[0][1].layer_names == [ + "model.layers.1", "model.layers.3" + ] + + +def test_initialize_kv_cache_for_kv_sharing_same_attn_groups(): + """ + Test case assuming that all layers in the same KV cache group have the same + attention backends. This is true for most models. + """ + shared_kv_cache_layers = { + "model.layers.2": "model.layers.0", + "model.layers.3": "model.layers.1", + } + + kv_cache_groups = [ + KVCacheGroupSpec(["model.layers.0", "model.layers.1"], + new_kv_cache_spec()), + ] + + attn_groups = [ + # KV cache group 0 has a single attention group + # as all layers have the same flash attention backend + [ + AttentionGroup( + backend=FlashAttentionBackend, + metadata_builder=Mock(spec=FlashAttentionMetadataBuilder), + layer_names=["model.layers.0", "model.layers.1"], + ), + ], + ] + + kv_caches = { + "model.layers.0": torch.zeros(1, 2, 3), + "model.layers.1": torch.ones(1, 2, 3), + } + + initialize_kv_cache_for_kv_sharing( + shared_kv_cache_layers=shared_kv_cache_layers, + kv_cache_groups=kv_cache_groups, + kv_caches=kv_caches, + attn_groups=attn_groups, + ) + + # Check that the KV caches were shared correctly + assert kv_caches["model.layers.2"].data_ptr( + ) == kv_caches["model.layers.0"].data_ptr() + assert kv_caches["model.layers.3"].data_ptr( + ) == kv_caches["model.layers.1"].data_ptr() + + # Check that the layers were added to the correct KV cache group + assert len(kv_cache_groups) == 1 + assert kv_cache_groups[0].layer_names == [ + "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3" + ] + + # Check that the layers were added to the attention groups + assert len(attn_groups) == 1 and len(attn_groups[0]) == 1 + assert attn_groups[0][0].layer_names == [ + 
"model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3" + ] + + +def test_initialize_kv_cache_for_kv_sharing_no_attn_groups(): + """ + Test KV sharing set up when no attention groups are provided. + This is the case for the TPU model runner, which doesn't have + support for attention groups yet. + """ + shared_kv_cache_layers = { + "model.layers.2": "model.layers.0", + "model.layers.3": "model.layers.1", + } + + kv_cache_groups = [ + KVCacheGroupSpec(["model.layers.0"], new_kv_cache_spec()), + KVCacheGroupSpec(["model.layers.1"], new_kv_cache_spec()), + ] + + kv_caches = { + "model.layers.0": torch.zeros(1, 2, 3), + "model.layers.1": torch.ones(1, 2, 3), + } + + initialize_kv_cache_for_kv_sharing( + shared_kv_cache_layers=shared_kv_cache_layers, + kv_cache_groups=kv_cache_groups, + kv_caches=kv_caches, + ) + + # Check that the KV caches were shared correctly + assert kv_caches["model.layers.2"].data_ptr( + ) == kv_caches["model.layers.0"].data_ptr() + assert kv_caches["model.layers.3"].data_ptr( + ) == kv_caches["model.layers.1"].data_ptr() + + # Check that the layers were added to the correct KV cache group + assert len(kv_cache_groups) == 2 + assert kv_cache_groups[0].layer_names == [ + "model.layers.0", "model.layers.2" + ] + assert kv_cache_groups[1].layer_names == [ + "model.layers.1", "model.layers.3" + ] diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index 0ab4e0bf59cf5..118b40d0ef418 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -11,7 +11,8 @@ import torch from vllm.multimodal.inputs import (MultiModalBatchedField, MultiModalFieldElem, MultiModalFlatField, - MultiModalKwargs, MultiModalKwargsItem, + MultiModalKwargsItem, + MultiModalKwargsItems, MultiModalSharedField, NestedTensors) from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -96,42 +97,10 @@ def test_encode_decode(monkeypatch: pytest.MonkeyPatch): class MyRequest(msgspec.Struct): - mm: 
Optional[list[MultiModalKwargs]] + mm: Optional[list[MultiModalKwargsItems]] def test_multimodal_kwargs(): - d = { - "foo": - torch.zeros(20000, dtype=torch.float16), - "bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)], - "baz": [ - torch.rand((256), dtype=torch.float16), - [ - torch.rand((1, 12), dtype=torch.float32), - torch.rand((3, 5, 7), dtype=torch.float64), - ], [torch.rand((4, 4), dtype=torch.float16)] - ], - } - - # pack mm kwargs into a mock request so that it can be decoded properly - req = MyRequest(mm=[MultiModalKwargs(d)]) - - encoder = MsgpackEncoder() - decoder = MsgpackDecoder(MyRequest) - - encoded = encoder.encode(req) - - assert len(encoded) == 6 - - total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) - - # expected total encoding length, should be 44559, +-20 for minor changes - assert 44539 <= total_len <= 44579 - decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] - assert all(nested_equal(d[k], decoded[k]) for k in d) - - -def test_multimodal_items_by_modality(): e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000, dtype=torch.bfloat16), MultiModalBatchedField()) @@ -151,7 +120,7 @@ def test_multimodal_items_by_modality(): audio = MultiModalKwargsItem.from_elems([e1]) video = MultiModalKwargsItem.from_elems([e2]) image = MultiModalKwargsItem.from_elems([e3, e4]) - mm = MultiModalKwargs.from_items([audio, video, image]) + mm = MultiModalKwargsItems.from_seq([audio, video, image]) # pack mm kwargs into a mock request so that it can be decoded properly req = MyRequest([mm]) @@ -165,19 +134,22 @@ def test_multimodal_items_by_modality(): total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) - # expected total encoding length, should be 14255, +-20 for minor changes - assert 14250 <= total_len <= 14300 - decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] + # expected total encoding length, should be 14306, +-20 for minor changes + assert 14275 <= total_len <= 14325 + decoded = 
decoder.decode(encoded).mm[0] + assert isinstance(decoded, MultiModalKwargsItems) # check all modalities were recovered and do some basic sanity checks - assert len(decoded.modalities) == 3 - images = decoded.get_items("image") + assert len(decoded) == 3 + images = decoded["image"] assert len(images) == 1 assert len(images[0].items()) == 2 assert list(images[0].keys()) == ["i0", "i1"] # check the tensor contents and layout in the main dict - assert all(nested_equal(mm[k], decoded[k]) for k in mm) + mm_data = mm.get_data() + decoded_data = decoded.get_data() + assert all(nested_equal(mm_data[k], decoded_data[k]) for k in mm_data) def nested_equal(a: NestedTensors, b: NestedTensors): diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 5a05781a03f2a..941aa0a77692c 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -85,7 +85,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -164,7 +164,7 @@ def test_update_states_request_finished(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids={req_id}, - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -194,7 +194,7 @@ def test_update_states_request_resumed(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -221,7 +221,7 @@ def test_update_states_request_resumed(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], 
structured_output_request_ids={}, grammar_bitmask=None, ) @@ -252,7 +252,7 @@ def test_update_states_no_changes(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -287,7 +287,7 @@ def test_update_states_request_unscheduled(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 74ab19a3ce32d..7031859078264 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -13,7 +13,7 @@ from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.v1.pool.metadata import PoolingMetadata -from vllm.v1.sample.logits_processor import LogitsProcessorManager +from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -169,7 +169,7 @@ def _construct_expected_sampling_metadata( and all(x == 1 for x in repetition_penalties)), allowed_token_ids_mask=allowed_token_ids_mask, bad_words_token_ids=bad_words_token_ids, - logitsprocs=LogitsProcessorManager(), + logitsprocs=LogitsProcessors(), ) @@ -205,6 +205,7 @@ def _construct_cached_request_state(req_id_suffix: int): pooling_params=None, mm_kwargs=[], mm_positions=[], + mm_hashes=[], block_ids=([], ), generator=None, num_computed_tokens=len(output_token_ids), diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 
4bcc63f293e03..d6cd03fb01a73 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -141,7 +141,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -207,7 +207,7 @@ def test_update_states_request_finished(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids={req_id}, - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -239,7 +239,7 @@ def test_update_states_request_resumed(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -266,7 +266,7 @@ def test_update_states_request_resumed(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -347,7 +347,7 @@ def test_update_states_no_changes(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -384,7 +384,7 @@ def test_update_states_request_unscheduled(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -680,6 +680,7 @@ def test_init_kv_cache_with_kv_sharing_valid(): kv_cache_spec[layer_0].page_size_bytes runner.initialize_kv_cache(kv_cache_config) + 
kv_cache_config_after_init = runner.kv_cache_config layer_0_kv = vllm_ctx[layer_0].kv_cache[0] layer_1_kv = vllm_ctx[layer_1].kv_cache[0] @@ -687,10 +688,12 @@ def test_init_kv_cache_with_kv_sharing_valid(): assert id(layer_1_kv) == id(layer_0_kv) # check layer 1 added to kv cache group's layer names - assert len(kv_cache_config.kv_cache_groups) == 1 - assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2 - assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0 - assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1 + assert len(kv_cache_config_after_init.kv_cache_groups) == 1 + assert len(kv_cache_config_after_init.kv_cache_groups[0].layer_names) == 2 + assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[ + 0] == layer_0 + assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[ + 1] == layer_1 def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index 1b797074096ed..cc18c9ff1f096 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -26,9 +26,5 @@ compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main -marlin, nm-testing/zephyr-beta-7b-marlin-g128, main -marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main -qqq, HandH1998/QQQ-Llama-3-8b-g128, main -qqq, HandH1998/QQQ-Llama-3-8b, main hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main None, mgleize/fairseq2-dummy-Llama-3.2-1B, main \ No newline at end of file diff --git a/tools/check_pickle_imports.py b/tools/check_pickle_imports.py index 444e2bf53f995..ad0ae45d1d465 100644 --- a/tools/check_pickle_imports.py +++ b/tools/check_pickle_imports.py @@ -37,7 +37,7 @@ ALLOWED_FILES = set([ 'vllm/distributed/utils.py', 'vllm/distributed/parallel_state.py', 
'vllm/engine/multiprocessing/client.py', - 'vllm/distributed/device_communicators/custom_all_reduce_utils.py', + 'vllm/distributed/device_communicators/all_reduce_utils.py', 'vllm/distributed/device_communicators/shm_broadcast.py', 'vllm/engine/multiprocessing/engine.py', 'benchmarks/kernels/graph_machete_bench.py', diff --git a/tools/ep_kernels/install_python_libraries.sh b/tools/ep_kernels/install_python_libraries.sh index e163c83e8b513..59bfe69dc0dd6 100644 --- a/tools/ep_kernels/install_python_libraries.sh +++ b/tools/ep_kernels/install_python_libraries.sh @@ -77,6 +77,7 @@ clone_repo() { local repo_url=$1 local dir_name=$2 local key_file=$3 + local commit_hash=$4 if [ -d "$dir_name" ]; then # Check if directory has uncommitted changes (dirty) @@ -87,17 +88,27 @@ clone_repo() { echo "$dir_name directory exists but clone appears incomplete, cleaning up and re-cloning" rm -rf "$dir_name" git clone "$repo_url" + if [ -n "$commit_hash" ]; then + cd "$dir_name" + git checkout "$commit_hash" + cd .. + fi else echo "$dir_name directory exists and appears complete; manually update if needed" fi else git clone "$repo_url" + if [ -n "$commit_hash" ]; then + cd "$dir_name" + git checkout "$commit_hash" + cd .. + fi fi } # build and install pplx, require pytorch installed pushd $WORKSPACE -clone_repo "https://github.com/ppl-ai/pplx-kernels" "pplx-kernels" "setup.py" +clone_repo "https://github.com/ppl-ai/pplx-kernels" "pplx-kernels" "setup.py" "c336faf" cd pplx-kernels # see https://github.com/pypa/pip/issues/9955#issuecomment-838065925 # PIP_NO_BUILD_ISOLATION=0 disables build isolation @@ -106,7 +117,7 @@ popd # build and install deepep, require pytorch installed pushd $WORKSPACE -clone_repo "https://github.com/deepseek-ai/DeepEP" "DeepEP" "setup.py" +clone_repo "https://github.com/deepseek-ai/DeepEP" "DeepEP" "setup.py" "e3908bf" cd DeepEP export NVSHMEM_DIR=$WORKSPACE/nvshmem_install PIP_NO_BUILD_ISOLATION=0 pip install -vvv -e . 
diff --git a/tools/install_deepgemm.sh b/tools/install_deepgemm.sh new file mode 100755 index 0000000000000..33849581d2c0e --- /dev/null +++ b/tools/install_deepgemm.sh @@ -0,0 +1,108 @@ +#!/bin/bash +# Script to install DeepGEMM from source +# This script can be used both in Docker builds and by users locally + +set -e + +# Default values +DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git" +DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c" + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --ref) + if [[ -z "$2" || "$2" =~ ^- ]]; then + echo "Error: --ref requires an argument." >&2 + exit 1 + fi + DEEPGEMM_GIT_REF="$2" + shift 2 + ;; + --cuda-version) + if [[ -z "$2" || "$2" =~ ^- ]]; then + echo "Error: --cuda-version requires an argument." >&2 + exit 1 + fi + CUDA_VERSION="$2" + shift 2 + ;; + -h|--help) + echo "Usage: $0 [OPTIONS]" + echo "Options:" + echo " --ref REF Git reference to checkout (default: $DEEPGEMM_GIT_REF)" + echo " --cuda-version VER CUDA version (auto-detected if not provided)" + echo " -h, --help Show this help message" + exit 0 + ;; + *) + echo "Unknown option: $1" >&2 + exit 1 + ;; + esac +done + +# Auto-detect CUDA version if not provided +if [ -z "$CUDA_VERSION" ]; then + if command -v nvcc >/dev/null 2>&1; then + CUDA_VERSION=$(nvcc --version | grep "release" | sed -n 's/.*release \([0-9]\+\.[0-9]\+\).*/\1/p') + echo "Auto-detected CUDA version: $CUDA_VERSION" + else + echo "Warning: Could not auto-detect CUDA version. 
Please specify with --cuda-version" + exit 1 + fi +fi + +# Extract major and minor version numbers +CUDA_MAJOR="${CUDA_VERSION%%.*}" +CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}" +CUDA_MINOR="${CUDA_MINOR%%.*}" + +echo "CUDA version: $CUDA_VERSION (major: $CUDA_MAJOR, minor: $CUDA_MINOR)" + +# Check CUDA version requirement +if [ "$CUDA_MAJOR" -lt 12 ] || { [ "$CUDA_MAJOR" -eq 12 ] && [ "$CUDA_MINOR" -lt 8 ]; }; then + echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})" + exit 0 +fi + +echo "Installing DeepGEMM from source..." +echo "Repository: $DEEPGEMM_GIT_REPO" +echo "Reference: $DEEPGEMM_GIT_REF" + +# Create a temporary directory for the build +INSTALL_DIR=$(mktemp -d) +trap 'rm -rf "$INSTALL_DIR"' EXIT + +# Clone the repository +git clone --recursive --shallow-submodules "$DEEPGEMM_GIT_REPO" "$INSTALL_DIR/deepgemm" + +echo "🏗️ Building DeepGEMM" +pushd "$INSTALL_DIR/deepgemm" + +# Checkout the specific reference +git checkout "$DEEPGEMM_GIT_REF" + +# Build DeepGEMM +# (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh) +rm -rf build dist +rm -rf *.egg-info +python3 setup.py bdist_wheel + +# Install the wheel +if command -v uv >/dev/null 2>&1; then + echo "Installing DeepGEMM wheel using uv..." + # Use --system in Docker contexts, respect user's environment otherwise + if [ -n "$VLLM_DOCKER_BUILD_CONTEXT" ]; then + uv pip install --system dist/*.whl + else + uv pip install dist/*.whl + fi +else + echo "Installing DeepGEMM wheel using pip..." 
+ python3 -m pip install dist/*.whl +fi + +popd + +echo "✅ DeepGEMM installation completed successfully" \ No newline at end of file diff --git a/tools/profiler/nsys_profile_tools/README.md b/tools/profiler/nsys_profile_tools/README.md new file mode 100644 index 0000000000000..9577efb68fb4b --- /dev/null +++ b/tools/profiler/nsys_profile_tools/README.md @@ -0,0 +1,174 @@ +# gputrc2graph.py + +This script processes NVIDIA Nsight Systems (`nsys`) GPU trace files +(`.nsys-rep`) with -t cuda tracing enabled, and generates kernel-level +summaries and visualizations of GPU and non-GPU time. It is useful for +profiling and analyzing nsys profile output. + +## Usage + +### Command-line Arguments + +- `--in_file` + **(required)** + List of input files and their metadata. Each entry should be in the format: + `<nsys-rep>,<engine>,<model>,<elapsed_nonprofiled_sec>` + - `nsys-rep`: Path to the `.nsys-rep` file. + - `engine`: Engine name (e.g., `vllm`). + - `model`: Model name (e.g., `llama`, `gpt-oss`, `ds`). + - `elapsed_nonprofiled_sec`: Wall-clock runtime (in seconds) without + profiling. Specify `0` to use the elapsed time from the nsys-rep file + (this may inflate non-GPU time if actual runtime without profiling is + less). Multiple entries can be provided, separated by spaces. + +- `--out_dir` + Output directory for the generated CSV and HTML files. + If not specified, results are saved in the current directory. + +- `--title` + Title for the HTML chart/visualization. + +- `--nsys_cmd` + Path to the `nsys` command. + Default: `nsys` (assumes it is in your PATH). + Use this if `nsys` is not in your system PATH. + +## Notes + +- Make sure you have pandas installed. +- Make sure [nsys](https://developer.nvidia.com/nsight-systems/get-started) is installed, and specify the path to the `nsys` command with `--nsys_cmd` if it is not in your PATH. 
+- For more details on available engines and models, see the help string in + the script or run: + +```bash +python3 gputrc2graph.py --help +``` + +## Example 1: analyze a single profile + +To analyze the GPU cycles for, say, the gpt-oss model with the vLLM engine: + +1. Run the following command to collect nsys profile, for vllm serve config. + + ```bash + nsys profile -t cuda -o run1 -f true --trace-fork-before-exec=true \ + --cuda-graph-trace=node --delay <DELAY> --duration <DURATION> \ + vllm serve openai/gpt-oss-120b ... + ``` + + where: + + - DELAY: how many seconds to delay nsys from collecting profiles, needed so + that profiles aren't captured till vllm server has come up and load + generation starts. + - DURATION: how many seconds for nsys profile to run before generating the + profile. This should be > the duration of the run. + +2. Run again, this time without collecting the profile, and get the total run + time in seconds. This value will be used by the script to calculate the + CPU(non-GPU) seconds for the analysis. + +3. Say the run elapsed time is 306 seconds, from step #2. Run script to + analyze: + + ```bash + python3 gputrc2graph.py \ + --in_file run1.nsys-rep,vllm,gpt-oss,306 \ + --title "vLLM-gpt-oss profile" + ``` + +The command will produce 2 files for analysis: + +- result.html: this categorizes kernel names into different categories in a + stacked bar chart. +- result.csv: shows how the kernel names are mapped to the different + categories. + +### HTML visualization with result.html + +The html file shows the number of elapsed seconds spent in the different GPU +substages, or categories. In this example, moe_gemm (Mixture of Experts GEMM) +kernels are the biggest category, at 148 seconds, followed by "attn" or attention +kernels. This lets the user prioritize the kernels to focus on for performance +optimizations.
+ +![Example GPU Trace Visualization](images/html.png) + +There's also an appended data table underneath the bar chart for copying out to other post-processing tools. + +![Example GPU Trace Table](images/html_tbl.png) + +### Kernel to category mapping with result.csv + +Suppose the user would like to focus on improving triton kernels. It's not the +biggest consumer of cycles at 9.74 sec but perhaps it hasn't been optimized. +The next step is to use the result.csv to dive into which kernels +make up the triton kernel GPU cycles. The following image shows that the +triton_poi_fused__to_copy_add_addmm_cat_.. kernel is the biggest +contributor to GPU cycles. + +![Example GPU Trace csv](images/csv1.png) + +## Example 2: analyze multiple profiles + +Suppose the user has multiple nsys trace files, captured for different models, +say llama and gpt-oss in this case, and wishes to compare their GPU/non-GPU +time. A command like the following can be used. + +```bash +python3 gputrc2graph.py \ +--in_file run1.nsys-rep,vllm,llama,100 run2.nsys-rep,vllm,gpt-oss,102 \ +--out_dir results \ +--title "Comparison of vLLM Models" +``` + +The analysis process is similar to example 1 but now there will be multiple +stack bar charts that can be compared. The categories for the different +kernels will remain the same, so that it's easy to compare the GPU cycles for +the same categories. + +Once a category is shown to have more cycles for one configuration than +another, the next step would be to use the csv file to see what kernels are +mapped into that category, and which kernels are taking the largest amount of +time which would cause a difference for the overall category. + +## Example 3: add new classification for a new model + +To create a new engine DEF with model ABC, just add another json file in the same directory as +gputrc2graph.py with the same format as the other json files.
The script will automatically pick up all the json files in the same directory as engine/model specifications. + +Then, for this new model, suppose there are 4 kernels to be classified into "gemm" and "attn", where the gemm kernels +have names with "*H*" or "*I*" in them, and attn kernels have names with "*J*" +or "*K*" in them, just add another .json file in the same directory as +gputrc2graph.py with the same format as the other json files, like the following: + +```json +{ + "DEF": { + "ABC": { + "H|I": "gemm", + "J|K": "attn", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + } + } +} +``` + +Each entry in the dictionary consists of: + +- key: a regex used to classify the kernels +- value: the category to classify the kernels into. + +The last 2 entries are common for all engine/models, consisting of CUDA memory +operations and a 'misc' for anything that's leftover and can't be classified. + +When invoking gputrc2graph.py, specify a trace file with this new model/engine +like the following: + +```bash +--in_file new.nsys-rep,DEF,ABC,<runtime> +``` + +If the engine_DEF.json file already exists, just add the model as a new node in +the existing engine file, after the other models. diff --git a/tools/profiler/nsys_profile_tools/gputrc2graph.py b/tools/profiler/nsys_profile_tools/gputrc2graph.py new file mode 100755 index 0000000000000..42dfede9e9870 --- /dev/null +++ b/tools/profiler/nsys_profile_tools/gputrc2graph.py @@ -0,0 +1,313 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" + This generates gpu kernel analysis output from nsys rep.
Will call nsys + stats -r cuda_gpu_kern_trace, get non-overlapped gpu cycles, then generate + csv and html output for analysis +""" +import argparse +import logging +import os + +import regex as re + +logger = logging.getLogger(__name__) + + +# helper data class for annotating kernels +def load_engine_model(): + """ returns engine_model built from all json files in the current dir """ + import glob + import json + engine_model = {} + + json_files = glob.glob( + os.path.join(os.path.dirname(__file__) or ".", "*.json")) + for fname in json_files: + with open(fname, encoding="utf-8") as f: + engine_model.update(json.load(f)) + return engine_model + + +class GPUTrace2Graph: + """ + Parses output of nsys report, generates csv and bar chart output + """ + + def __init__(self): + import pandas as pd # avoid importing till needed + self.pd = pd + self.pd.options.mode.copy_on_write = True + + # helper functions for generating trace->summary csvs + def gen_nonoverlapped_sum_from_gputrace(self, in_file, out_file): + logger.info('loading %s', in_file) + df = self.pd.read_csv( + in_file, + usecols=['Start (ns)', 'Duration (ns)', 'Device', 'Strm', 'Name']) + df['End (ns)'] = df['Start (ns)'] + df['Duration (ns)'] + df = self.sum_non_overlapping_intervals(df) + # get ready to print table with elapsed times per kernel + df['Instances'] = 1 + df_sum = df.groupby('Name', as_index=False).agg({ + 'Elapsed Time (ns)': 'sum', + 'Duration (ns)': 'sum', + 'Instances': 'size' + }) + + # generate csv + df_sum['Total Time (sec)'] = df_sum['Duration (ns)'] / 1e9 + df_sum['Elapsed Time (sec)'] = df_sum['Elapsed Time (ns)'] / 1e9 + df_sum = df_sum.sort_values(by='Elapsed Time (sec)', ascending=False) + df_sum[['Elapsed Time (sec)', 'Total Time (sec)', 'Instances', + 'Name']].to_csv(out_file, index=False) + + def sum_non_overlapping_intervals(self, df): + """ + returns new sorted df with Elapsed Time (ns) column using + vectorized operations + """ + logger.info("sorting %s trace records by start 
time", str(df.shape)) + + # Sort by start time and reset index + df = df.sort_values(by='Start (ns)').reset_index(drop=True) + + # Initialize elapsed time as duration + df['Elapsed Time (ns)'] = df['Duration (ns)'] + + # Get numpy arrays for faster operations + starts = df['Start (ns)'].values + ends = df['End (ns)'].values + + # Keep track of current interval end + current_end = ends[0] + display_units = int(len(df) / 100) + # Update current_end for overlapping intervals + for i in range(1, len(df)): + if i % display_units == 0: + print(f'processing trace: {int(i/len(df) * 100)} %', end="\r") + if starts[i] <= current_end: + if ends[i] > current_end: + # Partial overlap + df.iloc[i, df.columns.get_loc('Elapsed Time (ns)' + )] = ends[i] - current_end + current_end = ends[i] + else: + # Complete overlap + df.iloc[i, df.columns.get_loc('Elapsed Time (ns)')] = 0 + else: + # No overlap + current_end = ends[i] + + return df + + # functions for generating html files + def make_html(self, df, output_dir, title): + """ make html graph from df """ + import plotly.express as px + if df.empty: + return + output_name = output_dir + '/result' + if not title: + title = 'Model_Engine' + x = 'Model_Engine' + y = 'Elapsed Time (sec)' + color = 'Category' + """ generate kernel mapping table """ + # Sort Model_Engine categories by last field after underscore + df['Model_Engine'] = self.pd.Categorical( + df['Model_Engine'], + sorted(df['Model_Engine'].unique(), + key=lambda x: x.split('_')[-1])) + df[['Model_Engine', color, 'Instances', 'Name', + y]].sort_values(by=color).to_csv(f'{output_name}.csv', index=False) + graph = px.histogram(df.round(2), + x=x, + y=y, + title=(f'{y} for {title}'), + color=color, + text_auto=True) + # wrap x axis labels + graph.update_xaxes(automargin=True) + graph.write_html(f'{output_name}.html') + """ + Generate data table with columns per Model_Engine into result.html + """ + pivot_df = df.pivot_table(values='Elapsed Time (sec)', + index='Category', + 
columns='Model_Engine', + aggfunc='sum', + observed=False).round(2) + # Add sum row at bottom + pivot_df.loc['total_elapsed_sec'] = pivot_df.sum() + pivot_df.fillna('').to_html('temp.html') + with (open(f'{output_name}.html', 'a', encoding='utf-8') as + outfile, open('temp.html', encoding='utf-8') as infile): + outfile.write(infile.read()) + os.remove('temp.html') + + print(f'Finished generating: \n' + f' {output_name}.html for stack bar chart \n' + f' {output_name}.csv for Kernel-Category mapping') + + def anno_gpu_kernname(self, df, mapping): + """ add "Category" column """ + + def anno_gpu_kernname_helper(name): + for kern_name, val in mapping.items(): + if re.search(kern_name, name): + return val + + df['Category'] = df['Name'].apply(anno_gpu_kernname_helper) + + def make_nongpu_row(self, df, nongpu_sec): + """ this will append non-gpu time entry at end of df """ + nongpu_row = self.pd.DataFrame([df.iloc[-1]]) + nongpu_row['Category'] = nongpu_row['Name'] = 'CPU(non-GPU)' + nongpu_row['Instances'] = 1 + nongpu_row['Elapsed Time (sec)'] = nongpu_sec + return (nongpu_row) + + def is_valid_file(self, base_file): + """ asserts if base_file is non-existent or is empty """ + assert os.path.isfile(base_file) and os.path.getsize(base_file) > 0, \ + f"{base_file} doesn't exist or is empty" + + def should_gen_file(self, new_file, base_file): + """ figure out if new file should be generated from base_file """ + self.is_valid_file(base_file) + if (os.path.exists(new_file) + and (os.path.getmtime(new_file) > os.path.getmtime(base_file)) + and (os.path.getsize(base_file) > 0)): + logger.info('reusing %s', new_file) + return False + else: + logger.info('generating %s', new_file) + return True + + def gen_sum_file(self, file, nsys_cmd): + """ + generates sum file from nsys trace with times per kernel and + returns the name of the sum file + """ + import subprocess + file_dir = os.path.dirname(file) + file_name = os.path.basename(file) + + if not file_dir: + file_dir = '.' 
+ # Walk through trace and get the total non-overlapped time + nsys_stats_file = f'{file_dir}/{file_name}_cuda_gpu_trace.csv' + sum_file = f'{file_dir}/{file_name}_cuda_gpu_kernel_tracesum.csv' + if self.should_gen_file(nsys_stats_file, file): + cmd = [ + nsys_cmd, 'stats', '-r', 'cuda_gpu_trace', file, '-o', + f'{file_dir}/{file_name}' + ] + cmd_str = ' '.join(cmd) + logger.info('+ %s', cmd_str) + # estimate time based on calibrated 240M/min + file_size_mb = os.path.getsize(file) / 1e6 + logger.info( + 'nsys stats for %.2f MB file expected to take %.2f min', + file_size_mb, file_size_mb / 240) + try: + subprocess.run(cmd, check=True) + except Exception: + logger.error("%s failed; Use --nsys_cmd to specify nsys path", + cmd_str) + exit(1) + logger.info('generating non-overalapped sum %s', sum_file) + self.gen_nonoverlapped_sum_from_gputrace(nsys_stats_file, sum_file) + self.is_valid_file(sum_file) + logger.info('Finished generating %s', sum_file) + return sum_file + + def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model): + """ generates graph and csv file from in_file into out_dir """ + # Initialize an empty DataFrame to store combined data + combined_df = self.pd.DataFrame() + for idx, (file, engine, model, total_sec) in enumerate(in_file): + file_dir = os.path.dirname(file) + file_name = os.path.basename(file) + if not file_dir: + file_dir = '.' 
+ sum_file = self.gen_sum_file(file, nsys_cmd) + # read kernel summary file + df = self.pd.read_csv(sum_file) + # annotate kernel to their categories + assert engine_model.get(engine), f'engine {engine} unknown' + assert engine_model[engine].get(model), f'model {model} unknown' + # remove nsys-rep from file_name for shorter x-label + file_name = file_name.replace('.nsys-rep', '') + df['Model_Engine'] = f'{model}_{engine}_{file_name}_{idx}' + self.anno_gpu_kernname(df, engine_model[engine][model]) + # patch in non-gpu time + gpu_sec = round(df['Elapsed Time (sec)'].sum(), 1) + total_sec = round(float(total_sec), 1) + if total_sec < gpu_sec: + logger.warning( + "Elapsed sec %.2f < GPU sec %.2f resetting Elapsed sec ", + total_sec, + gpu_sec, + ) + total_sec = gpu_sec + nongpu_row = self.make_nongpu_row(df, total_sec - gpu_sec) + df = self.pd.concat([df, nongpu_row], ignore_index=True) + combined_df = self.pd.concat([combined_df, df], ignore_index=True) + if out_dir is None: + out_dir = '.' + else: + os.makedirs(out_dir, exist_ok=True) + # generate html file + self.make_html(combined_df, out_dir, title) + + +def parse_tuple(s): + return tuple(s.split(',')) + + +def main(): + logging.basicConfig(format=('%(asctime)s - %(levelname)s - %(message)s'), + level=logging.INFO) + parser = argparse.ArgumentParser( + description=( + 'Process nsys rep and generate kernel non-overlapped cycles. 
\n' + 'Example:\n' + "gputrc2graph.py --in_file d1.nsys-rep,vllm,llama,100 \n" + "d2.nsys-rep,vllm,gpt-oss,102 " + "--out_dir results/ --title \"Model=gpt-oss vLLM chart\""), + formatter_class=argparse.RawDescriptionHelpFormatter) + + # load supported engine_model + engine_model_supported = load_engine_model() + # Get a string representation of supported engine/model combinations + engine_model_supported_str = ', '.join( + f"{engine}:[{', '.join(models.keys())}]" + for engine, models in engine_model_supported.items()) + parser.add_argument( + '--in_file', + type=parse_tuple, + nargs='+', + help=( + 'list of (nsys-rep, engine, model, elapsed_nonprofiled_sec) ' + 'separated by space. Elapsed_nonprofiled_sec is runtime without ' + 'profiling used to calculate non-gpu time. Specify 0 to use ' + 'elapsed time from nsys-rep but that might inflate non-gpu time. ' + f'Available engine:[model] are: {engine_model_supported_str} ' + f'Example: --infile d1.nsys-rep,vllm,llama,100 ' + 'd2.nsys-rep,vllm,gpt-oss,102'), + required=True) + parser.add_argument('--out_dir', help=('output dir for result.csv/html')) + parser.add_argument('--title', help=('title for html chart')) + parser.add_argument('--nsys_cmd', + help=('nsys cmd, e.g. 
/usr/bin/nsys, Default: nsys'), + default="nsys") + args = parser.parse_args() + gputrace = GPUTrace2Graph() + gputrace.gen_graph(args.in_file, args.out_dir, args.title, args.nsys_cmd, + engine_model_supported) + + +if __name__ == '__main__': + main() diff --git a/tools/profiler/nsys_profile_tools/images/csv1.png b/tools/profiler/nsys_profile_tools/images/csv1.png new file mode 100644 index 0000000000000..bdeb47c3c2a35 Binary files /dev/null and b/tools/profiler/nsys_profile_tools/images/csv1.png differ diff --git a/tools/profiler/nsys_profile_tools/images/html.png b/tools/profiler/nsys_profile_tools/images/html.png new file mode 100644 index 0000000000000..c3cebdcc9971f Binary files /dev/null and b/tools/profiler/nsys_profile_tools/images/html.png differ diff --git a/tools/profiler/nsys_profile_tools/images/html_tbl.png b/tools/profiler/nsys_profile_tools/images/html_tbl.png new file mode 100644 index 0000000000000..0b47b6f31948e Binary files /dev/null and b/tools/profiler/nsys_profile_tools/images/html_tbl.png differ diff --git a/tools/profiler/nsys_profile_tools/vllm_engine_model.json b/tools/profiler/nsys_profile_tools/vllm_engine_model.json new file mode 100644 index 0000000000000..264c628dded34 --- /dev/null +++ b/tools/profiler/nsys_profile_tools/vllm_engine_model.json @@ -0,0 +1,63 @@ +{ + "vllm": { + "llama": { + "fused_moe_kernel|GroupProblemShape|group_gemm_starts|bmm_|GemmUniversal": "moe_gemm", + "gemm|nvjet": "gemm", + "moe|sigmoid": "moe", + "CatArrayBatched|prepare_inputs": "prepare_next", + "ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar", + "_norm_|Norm": "norm", + "act_and_mul_": "activation", + "Rotary": "rope", + "SoftMax": "softmax", + "flash|fmha": "attn", + "elementwise": "elementwise", + "fp8_quant|cvt_": "quantize", + "reduce_kernel": "reduce", + "triton": "triton_kernel", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + }, + "ds": { + "block_fp8|gemm_fp8_blockwise": "block_fp8_gemm", + 
"fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_": "moe_gemm", + "gemm|matmul|nvjet": "gemm", + "moe|sigmoid|expert": "moe", + "CatArrayBatched": "prepare_next", + "ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar", + "Norm|_norm_": "norm", + "sbtopk": "topk", + "act_and_mul_": "activation", + "compute_position_kernel": "rope", + "elementwise": "elementwise", + "fp8_quant|quant_fp8|cvt_": "quantize", + "reduce": "reduce", + "SoftMax": "softmax", + "_fwd_|FlashAttn|_mla_|_attn_|fmha": "attn", + "triton": "triton_kernel", + "topk": "topk", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + }, + "gpt-oss": { + "block_fp8|gemm_fp8_blockwise": "block_fp8_gemm", + "fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_|matmul_ogs_|_topk_forward|_combined_routing|_sum_bitmatrix_rows|_compute_writeback_idx": "moe_gemm", + "gemm|matmul|nvjet": "gemm", + "moe|sigmoid|expert|splitKreduce": "moe", + "CatArrayBatched": "prepare_next", + "ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar", + "Norm|_norm_": "norm", + "topk": "topk", + "act_and_mul_": "activation", + "compute_position_kernel": "rope", + "elementwise": "elementwise", + "fp8_quant|quant_fp8|cvt_|quantize": "quantize", + "reduce": "reduce", + "SoftMax": "softmax", + "_fwd_|FlashAttn|_mla_|_attn_|_flash_|flash::prepare_varlen|fmha": "attn", + "triton": "triton_kernel", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + } + } +} \ No newline at end of file diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a318637c5aeba..054dc9d985a4c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -387,14 +387,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) -# marlin -def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, - size_n: int, size_k: int) -> torch.Tensor: - return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, 
workspace, size_m, - size_n, size_k) - - # marlin_24 def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, @@ -437,25 +429,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): is_zp_float: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - @register_fake("_C::marlin_qqq_gemm") - def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - s_tok: torch.Tensor, s_ch: torch.Tensor, - s_group: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) - - @register_fake("_C::marlin_gemm") - def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) - @register_fake("_C::awq_dequantize") def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, split_k_iters: torch.SymInt, @@ -476,32 +449,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): dtype=input.dtype, device=input.device).sum(0) - @register_fake("_C::aqlm_gemm") - def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, - codebooks: torch.Tensor, scales: torch.Tensor, - codebook_partition_sizes: list[int], - bias: Optional[torch.Tensor]) -> torch.Tensor: - out_features = codes.size(0) * codebooks.size(2) - flat_input = input.reshape((-1, input.size(-1))) - flat_output = torch.empty((flat_input.size(0), out_features), - dtype=input.dtype, - device=input.device) - - output_sizes = list(input.shape) - output_sizes.pop() - output_sizes.append(-1) - return flat_output.reshape(tuple(output_sizes)) - - @register_fake("_C::aqlm_dequant") - def _aqlm_dequant_fake( - codes: torch.Tensor, codebooks: 
torch.Tensor, - codebook_partition_sizes: list[int]) -> torch.Tensor: - in_features = codes.size(1) * 8 - out_features = codes.size(0) - return torch.empty((out_features, in_features), - dtype=codebooks.dtype, - device=codebooks.device) - @register_fake("_C::machete_mm") def machete_mm_fake( a: torch.Tensor, @@ -527,6 +474,30 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) + @register_fake("_C::cutlass_w4a8_mm") + def cutlass_w4a8_mm_fake( + a: torch.Tensor, + # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b + b_q: torch.Tensor, + b_group_scales: torch.Tensor, + b_group_size: int, + b_channel_scales: torch.Tensor, + a_token_scales: torch.Tensor, + out_type: Optional[torch.dtype] = None, + maybe_schedule: Optional[str] = None) -> torch.Tensor: + m = a.size(0) + n = b_q.size(1) + out_dtype = out_type if out_type is not None else torch.bfloat16 + return torch.empty((m, n), device=a.device, dtype=out_dtype) + + @register_fake("_C::cutlass_pack_scale_fp8") + def cutlass_pack_scale_fp8_fake(scales: torch.Tensor) -> torch.Tensor: + return torch.empty_like(scales, memory_format=torch.contiguous_format) + + @register_fake("_C::cutlass_encode_and_reorder_int4b") + def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor: + return torch.empty_like(b, memory_format=torch.contiguous_format) + if hasattr(torch.ops._C, "allspark_w8a16_gemm"): @@ -870,6 +841,28 @@ def get_cutlass_moe_mm_data(topk_ids: torch.Tensor, blockscale_offsets) +def get_cutlass_moe_mm_problem_sizes( + topk_ids: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + num_experts: int, + n: int, + k: int, + blockscale_offsets: Optional[torch.Tensor] = None): + """ + Compute only the per-expert problem sizes needed by the two grouped matrix + multiplications used in CUTLASS-based fused MoE. 
+ + The function takes in topk_ids (token→expert mapping) and computes: + - problem_sizes1, problem_sizes2: M×N×K sizes of each expert's + multiplication for the two grouped MMs + used in the fused MoE operation. + """ + return torch.ops._C.get_cutlass_moe_mm_problem_sizes( + topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, + blockscale_offsets) + + def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor): """ Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor. @@ -957,21 +950,6 @@ def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, sf_offsets) -# aqlm -def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, - codebooks: torch.Tensor, scales: torch.Tensor, - codebook_partition_sizes: list[int], - bias: Optional[torch.Tensor]) -> torch.Tensor: - return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales, - codebook_partition_sizes, bias) - - -def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, - codebook_partition_sizes: list[int]) -> torch.Tensor: - return torch.ops._C.aqlm_dequant(codes, codebooks, - codebook_partition_sizes) - - # gptq_marlin def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, @@ -1078,6 +1056,30 @@ def machete_prepack_B( group_scales_type) +# CUTLASS W4A8 +def cutlass_w4a8_mm( + a: torch.Tensor, + # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b + b_q: torch.Tensor, + b_group_scales: torch.Tensor, + b_group_size: int, + b_channel_scales: torch.Tensor, + a_token_scales: torch.Tensor, + out_type: Optional[torch.dtype] = None, + maybe_schedule: Optional[str] = None) -> torch.Tensor: + return torch.ops._C.cutlass_w4a8_mm(a, b_q, b_group_scales, b_group_size, + b_channel_scales, a_token_scales, + out_type, maybe_schedule) + + +def cutlass_pack_scale_fp8(scales: torch.Tensor) -> torch.Tensor: + return torch.ops._C.cutlass_pack_scale_fp8(scales) + + +def 
cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor: + return torch.ops._C.cutlass_encode_and_reorder_int4b(b) + + if hasattr(torch.ops._C, "permute_cols"): @register_fake("_C::permute_cols") @@ -1367,15 +1369,6 @@ def scaled_int8_quant( return output, input_scales, input_azp -# qqq ops -def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - s_tok: torch.Tensor, s_ch: torch.Tensor, - s_group: torch.Tensor, workspace: torch.Tensor, - size_m: int, size_n: int, size_k: int) -> torch.Tensor: - return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group, - workspace, size_m, size_n, size_k) - - # gguf def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int, dtype: Optional[torch.dtype]) -> torch.Tensor: @@ -1509,6 +1502,17 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, gating_output) +def grouped_topk(scores: torch.Tensor, scores_with_bias: torch.Tensor, + num_expert_group: int, topk_group: int, topk: int, + renormalize: bool, routed_scaling_factor: float): + if not current_platform.is_cuda(): + raise NotImplementedError("The fused grouped_topk kernel is only " + "available on CUDA platforms") + return torch.ops._moe_C.grouped_topk(scores, scores_with_bias, + num_expert_group, topk_group, topk, + renormalize, routed_scaling_factor) + + def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], b_qweight: torch.Tensor, b_bias: Optional[torch.Tensor], @@ -1644,14 +1648,18 @@ def convert_fp8(output: torch.Tensor, torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) -def gather_cache(src_cache: torch.Tensor, - dst: torch.Tensor, - block_table: torch.Tensor, - cu_seq_lens: torch.Tensor, - batch_size: int, - seq_starts: Optional[torch.Tensor] = None) -> None: - torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, - cu_seq_lens, batch_size, seq_starts) +def gather_and_maybe_dequant_cache( + src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: 
torch.Tensor, + cu_seq_lens: torch.Tensor, + batch_size: int, + kv_cache_dtype: str, + scale: torch.Tensor, + seq_starts: Optional[torch.Tensor] = None) -> None: + torch.ops._C_cache_ops.gather_and_maybe_dequant_cache( + src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype, + scale, seq_starts) def get_device_attribute(attribute: int, device: int) -> int: @@ -1882,3 +1890,86 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"): M = mat1.size(0) N = mat2.size(0) return torch.empty((M, N), dtype=out_dtype) + + +class CPUDNNLGEMMHandler: + + def __init__(self) -> None: + self.handler: Optional[int] = None + self.n = -1 + self.k = -1 + + def __del__(self): + if self.handler is not None: + torch.ops._C.release_dnnl_matmul_handler(self.handler) + + +def create_onednn_scaled_mm( + weight: torch.Tensor, # [K, N] + weight_scales: torch.Tensor, + output_type: torch.dtype, + dynamic_quant: bool, + use_azp: bool, + primitive_cache_size: int = 128, +) -> CPUDNNLGEMMHandler: + handler = CPUDNNLGEMMHandler() + handler.k, handler.n = weight.size() + handler.handler = torch.ops._C.create_onednn_scaled_mm_handler( + weight, weight_scales, output_type, dynamic_quant, use_azp, + primitive_cache_size) + return handler + + +def onednn_scaled_int8_quant(input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True): + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). 
+ + Returns: + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + token_num = input.numel() // input.shape[-1] + input = input.view((token_num, input.shape[-1])) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp + is None), "azp must only be provided for asymmetric quantization." + torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. + input_scales = torch.empty((token_num, 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, + input_azp) + return output, input_scales, input_azp + + +def onednn_scaled_mm( + dnnl_handler: CPUDNNLGEMMHandler, + x: torch.Tensor, + output: torch.Tensor, + input_scale: Optional[torch.Tensor], + input_zp: Optional[torch.Tensor], + input_zp_adj: Optional[torch.Tensor], + bias: Optional[torch.Tensor], +) -> torch.Tensor: + torch.ops._C.onednn_scaled_mm(output, x, input_scale, input_zp, + input_zp_adj, bias, dnnl_handler.handler) + + return output diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 344040586a532..dcb2aa68fbee9 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -14,7 +14,6 @@ __all__ = [ "AttentionMetadata", "AttentionType", "AttentionMetadataBuilder", - "Attention", "AttentionState", "get_attn_backend", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index d21f07756871a..0b9c625533cb7 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -9,8 +9,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, import torch -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - 
GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.multimodal import MultiModalPlaceholderMap if TYPE_CHECKING: @@ -285,20 +284,17 @@ class AttentionImpl(ABC, Generic[T]): attn_metadata: T, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError - def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, - group_shape: GroupShape): + def fused_output_quant_supported(self, quant_key: QuantKey): """ Does this attention implementation support fused output quantization. This is used by the AttnFusionPass to only fuse output quantization onto implementations that support it. - TODO(luka) merge parameters into QuantDescriptor - :param dtype: quantized dtype - :param static: static or dynamic quantization - :param group_shape: quant group shape. + :param quant_key: QuantKey object that describes the quantization op :return: is fusion supported for this type of quantization """ return False @@ -317,6 +313,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]): attn_metadata: T, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index fac3c318a87a0..caa02530d2fd6 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -800,23 +800,33 @@ class DifferentialFlashAttentionImpl(AttentionImpl): attn_metadata: DifferentialFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. 
Args: - query: shape = [num_tokens, num_heads, head_size] - key: shape = [num_tokens, num_kv_heads, head_size] - value: shape = [num_tokens, num_kv_heads, head_size] - output: shape = [num_tokens, num_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + layer: Attention layer instance. + q: Query tensor with shape = [num_tokens, num_heads, head_size] + k: Key tensor with shape = [num_tokens, num_kv_heads, head_size] + v: Value tensor with shape = [num_tokens, num_kv_heads, head_size] + kv_cache: KV cache tensor with shape + [2, num_blocks, block_size, num_kv_heads, head_size]. NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. + output: Output tensor with shape [num_tokens, num_heads, head_size] + output_scale: Optional output scale tensor. + output_block_scale: Optional output block scale tensor. NOTE: It in-place updates the output tensor. NOTE: FP8 quantization, flash-attn expect the size of {q,k,v}_descale to be (num_sequences, num_kv_heads). We use torch's .expand() to avoid duplicating values """ + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for DifferentialFlashAttentionImpl") + if self.lambda_full is None: self.lambda_init = self.differential_flash_attention_config[ "lambda_init"] diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py index fa6f3f1b39cca..85957bea1e26d 100644 --- a/vllm/attention/backends/dual_chunk_flash_attn.py +++ b/vllm/attention/backends/dual_chunk_flash_attn.py @@ -371,6 +371,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): attn_metadata: DualChunkFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with DualChunkFlashAttention. 
Args: @@ -386,7 +387,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): """ assert output is None, "Output tensor not supported for DualChunk" - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index e52480d5c5ce2..d8cb208c4f2ea 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -596,6 +596,7 @@ class FlashAttentionImpl(AttentionImpl): attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -604,7 +605,8 @@ class FlashAttentionImpl(AttentionImpl): key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] output: shape = [num_tokens, num_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: KV cache tensor with shape + [2, num_blocks, block_size, num_kv_heads, head_size]. NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. @@ -615,7 +617,7 @@ class FlashAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." 
- if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") @@ -849,7 +851,7 @@ class FlashAttentionImpl(AttentionImpl): def _get_query_key_seq_metadata( - attn_metadata, + attn_metadata: FlashAttentionMetadata, is_prompt: bool, attn_type: str, ) -> tuple: diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py deleted file mode 100644 index 208cacec38eb5..0000000000000 --- a/vllm/attention/backends/flashinfer.py +++ /dev/null @@ -1,1098 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -from collections import defaultdict -from contextlib import contextmanager -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type - -from vllm.multimodal import MultiModalPlaceholderMap - -try: - from flashinfer import BatchDecodeWithPagedKVCacheWrapper - from flashinfer.decode import (CUDAGraphBatchDecodeWithPagedKVCacheWrapper, - trtllm_batch_decode_with_kv_cache) - from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper - - from vllm.vllm_flash_attn import flash_attn_varlen_func - FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 -except ImportError: - # Avoid turning these types into variables during type checking - if not TYPE_CHECKING: - BatchDecodeWithPagedKVCacheWrapper = None - CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None - BatchPrefillWithPagedKVCacheWrapper = None - trtllm_batch_decode_with_kv_cache = None - FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 - raise ImportError("FlashInfer is not installed. 
Please install it from " - "https://github.com/flashinfer-ai/flashinfer") from None - -import torch - -import vllm.envs as envs -from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionState, AttentionType) -from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, - compute_slot_mapping_start_idx, - is_block_tables_empty) -from vllm.attention.layer import Attention -from vllm.attention.ops.paged_attn import PagedAttention -from vllm.config import VllmConfig, get_layers_from_vllm_config -from vllm.logger import init_logger -from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, - make_tensor_with_pad) -from vllm.utils.flashinfer import use_trtllm_attention - -logger = init_logger(__name__) - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder - - -class FlashInferBackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - return "FLASHINFER" - - @staticmethod - def get_impl_cls() -> Type["FlashInferImpl"]: - return FlashInferImpl - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return FlashInferMetadata - - @staticmethod - def get_builder_cls() -> Type["FlashInferMetadataBuilder"]: - return FlashInferMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["FlashInferState"]: - return FlashInferState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return (num_blocks, 2, block_size, num_kv_heads, head_size) - - @staticmethod - def get_kv_cache_stride_order() -> Tuple[int, ...]: - cache_layout = FlashInferState.get_kv_cache_layout() - assert (cache_layout in ("NHD", "HND")) - stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, - 2, 4) - return stride_order - - @staticmethod - def swap_blocks( - src_kv_cache: 
torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [64, 128, 256] - - @staticmethod - def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: - if kv_cache_dtype in ("fp8", "fp8_e4m3"): - return torch.float8_e4m3fn - elif kv_cache_dtype == "fp8_e5m2": - return torch.float8_e5m2 - else: - raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") - - -@dataclass -class PerLayerParameters: - """ - Currently, FlashInfer backend only support models in which all layers share - the same values for the following hyperparameters. - """ - - window_left: int - logits_soft_cap: Optional[float] - sm_scale: float - - -def get_per_layer_parameters( - vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]: - """ - Scan all attention layers and determine some hyperparameters - to use during `plan`. 
- """ - - layers = get_layers_from_vllm_config(vllm_config, Attention) - per_layer_params: Dict[str, PerLayerParameters] = {} - - for key, layer in layers.items(): - impl = layer.impl - assert isinstance(impl, FlashInferImpl) - - # Infer hyperparameters from the attention layer - window_size = impl.sliding_window - window_left = window_size[0] if window_size is not None else -1 - logits_soft_cap = impl.logits_soft_cap - sm_scale = impl.scale - - per_layer_params[key] = PerLayerParameters(window_left, - logits_soft_cap, sm_scale) - - return per_layer_params - - -def infer_global_hyperparameters( - per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters: - """ - Currently, FlashInfer backend only support models in which all layers share - the same values for the following hyperparameters: - - `window_left` - - `logits_soft_cap` - - `sm_scale` - - So this function asserts that all layers share the same values for these - hyperparameters and returns the global values. - """ - - assert len(per_layer_params) > 0, "No attention layers found in the model." 
- - param_sets = list(per_layer_params.values()) - global_params = param_sets[0] - for params in param_sets: - assert params == global_params, ( - "FlashInfer backend currently only supports models in which all " - "layers share the same values for the following hyperparameters: " - "`window_left`, `logits_soft_cap`, `sm_scale`.") - - return global_params - - -class FlashInferState(AttentionState): - - def __init__(self, runner): - self.runner = runner - self._is_graph_capturing = False - self._workspace_buffer = None - self._decode_wrapper = None - self._prefill_wrapper = None - - # Global hyperparameters shared by all attention layers - self.global_hyperparameters: Optional[PerLayerParameters] = None - - self.vllm_config = self.runner.vllm_config - self._kv_cache_layout = None - - def _get_workspace_buffer(self): - if self._workspace_buffer is None: - self._workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.runner.device) - return self._workspace_buffer - - @staticmethod - def get_kv_cache_layout(): - from vllm.v1.attention.backends.utils import _KV_CACHE_LAYOUT_OVERRIDE - if _KV_CACHE_LAYOUT_OVERRIDE is not None: - logger.info_once("Using KV cache layout %s", - _KV_CACHE_LAYOUT_OVERRIDE) - return _KV_CACHE_LAYOUT_OVERRIDE - cache_layout = envs.VLLM_KV_CACHE_LAYOUT - if cache_layout is None: - logger.info_once("Using default KV cache layout NHD") - return "NHD" - logger.info_once("Using KV cache layout %s", cache_layout) - return cache_layout - - def _get_prefill_wrapper(self): - if self._prefill_wrapper is None: - self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self._get_workspace_buffer(), self.get_kv_cache_layout()) - return self._prefill_wrapper - - def _get_decode_wrapper(self): - if self._decode_wrapper is None: - num_qo_heads = (self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config)) - num_kv_heads = self.runner.model_config.get_num_kv_heads( - 
self.runner.parallel_config) - use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( - num_qo_heads // num_kv_heads > 4) - self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - self._get_workspace_buffer(), - self.get_kv_cache_layout(), - use_tensor_cores=use_tensor_cores) - return self._decode_wrapper - - @contextmanager - def graph_capture(self, max_batch_size: int): - self._is_graph_capturing = True - self._graph_decode_wrapper = None - self._graph_slot_mapping = torch.full((max_batch_size, ), - PAD_SLOT_ID, - dtype=torch.long, - device=self.runner.device) - self._graph_seq_lens = torch.ones(max_batch_size, - dtype=torch.int32, - device=self.runner.device) - self._graph_block_tables = torch.from_numpy( - self.runner.graph_block_tables).to(device=self.runner.device) - self._graph_decode_workspace_buffer = self._get_workspace_buffer() - self._graph_indices_buffer = torch.empty( - max_batch_size * self.runner.cache_config.num_gpu_blocks, - dtype=torch.int32, - device=self.runner.device) - self._graph_indptr_buffer = torch.empty(max_batch_size + 1, - dtype=torch.int32, - device=self.runner.device) - self._graph_last_page_len_buffer = torch.empty( - max_batch_size, dtype=torch.int32, device=self.runner.device) - yield - self._is_graph_capturing = False - del self._graph_slot_mapping - del self._graph_seq_lens - del self._graph_block_tables - del self._graph_decode_workspace_buffer - del self._graph_indices_buffer - del self._graph_indptr_buffer - del self._graph_last_page_len_buffer - del self._graph_decode_wrapper - - def graph_clone(self, batch_size: int): - assert self._is_graph_capturing - state = self.__class__(self.runner) - state._workspace_buffer = self._graph_decode_workspace_buffer - state._decode_wrapper = self._graph_decode_wrapper - state._prefill_wrapper = self._get_prefill_wrapper() - return state - - def graph_capture_get_metadata_for_batch( - self, batch_size: int, is_encoder_decoder_model: bool = False): - assert 
self._is_graph_capturing - _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1] - _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size] - - num_qo_heads = (self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config)) - num_kv_heads = self.runner.model_config.get_num_kv_heads( - self.runner.parallel_config) - use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( - num_qo_heads // num_kv_heads > 4) - self._graph_decode_wrapper = \ - CUDAGraphBatchDecodeWithPagedKVCacheWrapper( - self._graph_decode_workspace_buffer, _indptr_buffer, - self._graph_indices_buffer, _last_page_len_buffer, - self.get_kv_cache_layout(), - use_tensor_cores) - if self.runner.kv_cache_dtype.startswith("fp8"): - kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - self.runner.kv_cache_dtype) - else: - kv_cache_dtype = get_kv_cache_torch_dtype( - self.runner.kv_cache_dtype, self.runner.model_config.dtype) - - paged_kv_indptr_tensor_host = torch.arange(0, - batch_size + 1, - dtype=torch.int32) - paged_kv_indices_tensor_host = torch.arange(0, - batch_size, - dtype=torch.int32) - paged_kv_last_page_len_tensor_host = torch.full((batch_size, ), - self.runner.block_size, - dtype=torch.int32) - query_start_loc_host = torch.arange(0, - batch_size + 1, - dtype=torch.int32) - - global_params = infer_global_hyperparameters( - get_per_layer_parameters(self.vllm_config)) - - attn_metadata = self.runner.attn_backend.make_metadata( - num_prefills=0, - slot_mapping=self._graph_slot_mapping[:batch_size], - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - max_prefill_seq_len=0, - max_decode_seq_len=0, - seq_lens_tensor=self._graph_seq_lens, - block_tables=self._graph_block_tables, - paged_kv_indptr=paged_kv_indptr_tensor_host, - paged_kv_indices=paged_kv_indices_tensor_host, - paged_kv_last_page_len=paged_kv_last_page_len_tensor_host, - 
num_qo_heads=num_qo_heads, - num_kv_heads=num_kv_heads, - head_dim=self.runner.model_config.get_head_size(), - page_size=self.runner.block_size, - seq_start_loc=None, - query_start_loc=query_start_loc_host, - device=self.runner.device, - data_type=kv_cache_dtype, - q_data_type=self.runner.model_config.dtype, - use_cuda_graph=True, - decode_wrapper=self._graph_decode_wrapper, - prefill_wrapper=None, - **dataclasses.asdict(global_params), - ) - attn_metadata.begin_forward() - return attn_metadata - - def get_graph_input_buffers(self, - attn_metadata, - is_encoder_decoder_model: bool = False): - return { - "block_tables": attn_metadata.block_tables, - "seq_lens_tensor": attn_metadata.seq_lens_tensor, - "slot_mapping": attn_metadata.slot_mapping, - } - - def prepare_graph_input_buffers(self, - input_buffers, - attn_metadata, - is_encoder_decoder_model: bool = False): - # FlashInfer-specific logic: copy additional tensors - num_total_blocks = attn_metadata.decode_metadata.seq_lens_tensor.shape[ - 0] - input_buffers["seq_lens_tensor"][:num_total_blocks].copy_( - attn_metadata.seq_lens_tensor, non_blocking=True) - input_buffers["block_tables"][:num_total_blocks].copy_( - attn_metadata.block_tables, non_blocking=True) - - def begin_forward(self, model_input): - assert not self._is_graph_capturing - state = self - use_cuda_graph = model_input.attn_metadata.use_cuda_graph - is_decode = model_input.attn_metadata.num_prefills == 0 - # In case of multistep chunked-prefill, there might be prefill requests - # scheduled while CUDA graph mode is enabled. We don't run graph in that - # case. 
- if use_cuda_graph and is_decode: - if model_input.inputs_embeds is None: - batch_size = model_input.input_tokens.shape[0] - state = ( - self.runner.graph_runners[model_input.virtual_engine][( - batch_size, False)].attn_state) - else: - batch_size = model_input.inputs_embeds.shape[0] - state = ( - self.runner.graph_runners[model_input.virtual_engine][( - batch_size, True)].attn_state) - - model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( - ) - model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() - model_input.attn_metadata.begin_forward() - - -@dataclass -class FlashInferMetadata(AttentionMetadata): - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - max_decode_seq_len: int - - # Number of query tokens for each request in the batch. - # Currently, we require that all requests have the same number of query - # tokens during the decoding phase. When speculavie decoding is enabled, - # decode_query_len might be greater than 1. In all other cases, it is 1. 
- decode_query_len: Optional[int] = 1 - - use_cuda_graph: bool = True - - prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None - decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None - - # Metadata for the prefill stage - seq_start_loc: Optional[torch.Tensor] = None - query_start_loc: Optional[torch.Tensor] = None - block_tables: Optional[torch.Tensor] = None - - # used for GPU operations - seq_lens_tensor: Optional[torch.Tensor] = None - block_table_bound: Optional[torch.Tensor] = None - - # An example for paged_kv_indices, paged_kv_indptr: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - # The indptr of the paged kv cache, shape: [batch_size + 1] - paged_kv_indptr: Optional[torch.Tensor] = None - # The page indices of the paged kv cache - paged_kv_indices: Optional[torch.Tensor] = None - # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] - paged_kv_last_page_len: Optional[torch.Tensor] = None - # The number of query/output heads - num_qo_heads: Optional[int] = None - # The number of key/value heads - num_kv_heads: Optional[int] = None - # The dimension of the attention heads - head_dim: Optional[int] = None - # Block size of vllm - page_size: Optional[int] = None - # The data type of the paged kv cache - data_type: torch.dtype = None - # The data type of the query - q_data_type: torch.dtype = None - # FlashInfer 0.2 encourages passing host tensors - device: torch.device = torch.device("cpu") - is_profile_run: bool = False - - # The FlashInfer backend currently supports only models in which all layers - # share the same following hyperparameters: - - # The left (inclusive) window size for the attention window, when - # set to `-1`, the window 
size will be set to the full length of - # the sequence. Defaults to `-1`. - window_left: int = -1 - # The attention logits soft capping value (used in Gemini, Grok and - # Gemma-2, etc.), if not provided, will be set to `0`. If greater - # than 0, the logits will be capped according to formula: - # $$\texttt{logits\_soft\_cap} \times - # \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$, - # where $x$ is the input logits. - logits_soft_cap: Optional[float] = None - # The scale used in softmax, if not provided, will be set to - # `1.0 / sqrt(head_dim)`. - sm_scale: Optional[float] = None - - def __post_init__(self): - # Refer to - # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 - supported_head_sizes = FlashInferBackend.get_supported_head_sizes() - if self.head_dim is not None and self.head_dim \ - not in supported_head_sizes: - raise ValueError( - f"Only {supported_head_sizes} are supported for head_dim,", - f" received {self.head_dim}.") - - def begin_forward(self): - if self.num_prefill_tokens > 0: - if self.paged_kv_indices is None: - return - - assert self.prefill_wrapper is not None - assert self.query_start_loc is not None - assert self.paged_kv_indices is not None - assert self.paged_kv_indptr is not None - assert self.paged_kv_last_page_len is not None - assert self.block_table_bound is not None - assert self.seq_lens_tensor is not None - self.query_start_loc = self.query_start_loc[:self.num_prefills + 1] - batch_size = self.query_start_loc.shape[0] - 1 - assert batch_size >= 0 - # We will use flash attention for profiling to - # determine the number of blocks. Therefore, - # we don't need to prepare the input for flashinfer for profile run. 
- if not self.is_profile_run: - self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) - self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( - self.device) - self.block_table_bound = self.block_table_bound.to(self.device) - self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) - self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.prefill_wrapper.plan( - self.query_start_loc, - self.paged_kv_indptr[:self.num_prefills + 1], - self.paged_kv_indices, - self.paged_kv_last_page_len[:self.num_prefills], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - causal=True, - sm_scale=self.sm_scale, - window_left=self.window_left, - logits_soft_cap=self.logits_soft_cap, - q_data_type=self.q_data_type, - kv_data_type=self.data_type) - if self.num_decode_tokens > 0: - assert self.paged_kv_indices is not None - assert self.paged_kv_indptr is not None - assert self.paged_kv_last_page_len is not None - self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) - self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( - self.device) - # handle model warmup path - if self.block_table_bound is not None: - self.block_table_bound = self.block_table_bound.to(self.device) - if self.seq_lens_tensor is not None: - self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) - - assert self.decode_wrapper is not None - self.decode_wrapper.plan( - self.paged_kv_indptr[self.num_prefills:], - self.paged_kv_indices, - self.paged_kv_last_page_len[self.num_prefills:], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - # Disable flashinfer's pos encoding and use vllm's rope. - pos_encoding_mode="NONE", - window_left=self.window_left, - logits_soft_cap=self.logits_soft_cap, - sm_scale=self.sm_scale, - # kv-cache data type. - kv_data_type=self.data_type, - # query data type. 
- q_data_type=self.q_data_type) - - def asdict_zerocopy(self, - skip_fields: Optional[Set[str]] = None - ) -> Dict[str, Any]: - if skip_fields is None: - skip_fields = set() - # We need to skip the prefill/decode_wrapper field since it cannot be - # broadcasted with nccl when TP is enabled. - skip_fields.add('prefill_wrapper') - skip_fields.add('decode_wrapper') - return super().asdict_zerocopy(skip_fields) - - @property - def prefill_metadata(self) -> Optional["FlashInferMetadata"]: - if self.num_prefills == 0: - return None - return self - - @property - def decode_metadata(self) -> Optional["FlashInferMetadata"]: - if self.num_decode_tokens == 0: - return None - return self - - -class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): - - def __init__(self, input_builder: "ModelInputForGPUBuilder"): - - self.input_builder = input_builder - self.runner = input_builder.runner - - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - - # Global hyperparameters shared by all attention layers - self.global_hyperparameters: Optional[PerLayerParameters] = None - - self.vllm_config = self.runner.vllm_config - - def prepare(self): - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.curr_seq_lens: List[int] = [] - self.multimodal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - - # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout - # for the precise definition of the following fields. 
- # An example: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - self.paged_kv_indices: List[int] = [] - # 0 at the beginning of paged_kv_indptr indicates the start of the - # first request’s page indices in the paged_kv_indices list. - self.paged_kv_indptr: List[int] = [0] - # paged_kv_last_page_len is the length of the last page of each request - self.paged_kv_last_page_len: List[int] = [] - self.total_blocks = 0 - self.is_profile_run: bool = False - - if self.global_hyperparameters is None: - # Infer global hyperparameters, since currently we only support - # models in which all layers share the same values for the - # following hyperparameters: - # - `window_left` - # - `logits_soft_cap` - # - `sm_scale` - inferred_params = infer_global_hyperparameters( - get_per_layer_parameters(self.vllm_config)) - self.global_hyperparameters = inferred_params - self.window_left = inferred_params.window_left - self.logits_soft_cap = inferred_params.logits_soft_cap - self.sm_scale = inferred_params.sm_scale - - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. 
- """ - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - computed_block_nums = inter_data.computed_block_nums - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - if is_prompt: - mm_maps = inter_data.multi_modal_placeholder_maps - if mm_maps: - for modality, placeholders in mm_maps.items(): - self.multimodal_placeholder_maps[modality].extend( - placeholders) - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if inter_data.prefix_cache_hit: - block_table = computed_block_nums - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - block_table = block_tables[seq_id][-curr_sliding_window_block:] - self.block_tables.append(block_table) - - is_profile_run = is_block_tables_empty(block_tables) - - # Compute slot mapping. - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - - # It is not necessary to add paged_kv_indices, paged_kv_indptr, - # and paged_kv_last_page_len for profile run because we will - # create dummy inputs. 
- if is_profile_run: - self.is_profile_run = is_profile_run - return - - block_table = block_tables[seq_id] - self._update_paged_kv_tensors(block_table, seq_len) - - def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): - # Get the number of valid blocks based on sequence length. - # If seq_len = 16, block_size = 16, - # block_table_bound is 1 with 1 valid block. - # If seq_len = 15, block_size = 16, - # block_table_bound is 0 + 1 with 1 valid block. - self.total_blocks += len(block_table) - block_table_bound = seq_len // self.block_size + 1 \ - if seq_len % self.block_size != 0 \ - else seq_len // self.block_size - self.paged_kv_indices.extend(block_table[:block_table_bound]) - self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + - block_table_bound) - - last_page_len = seq_len % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - self.paged_kv_last_page_len.append(last_page_len) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. 
- """ - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - decode_query_len = max(query_lens[self.num_prefills:], default=1) - - if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) - num_decode_tokens = batch_size - self.num_prefill_tokens - - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - input_block_tables = self.runner.graph_block_tables[:batch_size] - max_blocks = input_block_tables.shape[1] - for i, block_table in enumerate(self.block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - input_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. 
- input_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - block_tables = torch.from_numpy(input_block_tables).to( - device, non_blocking=True) - - last_paged_kv_indptr = self.paged_kv_indptr[-1] - self.paged_kv_indptr.extend([last_paged_kv_indptr] * - cuda_graph_pad_size) - self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - - assert device is not None - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, - self.runner.pin_memory) - slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, - device, self.runner.pin_memory) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - self.multimodal_placeholder_maps.items() - } - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) - - if len(self.paged_kv_indptr) > 0: - # extend to the maximum number of blocks as returned by the - # scheduler - self.paged_kv_indices.extend( - [0] * (self.total_blocks - len(self.paged_kv_indices))) - paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, - device="cpu", - dtype=torch.int) - paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, - device="cpu", - dtype=torch.int) - paged_kv_last_page_len_tensor = torch.tensor( - self.paged_kv_last_page_len, device="cpu", dtype=torch.int) - block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - - 1, - device="cpu", - dtype=torch.int) - else: - paged_kv_indices_tensor = None 
- paged_kv_indptr_tensor = None - paged_kv_last_page_len_tensor = None - block_table_bound_tensor = None - - if self.runner.kv_cache_dtype.startswith("fp8"): - kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - self.runner.kv_cache_dtype) - else: - kv_cache_dtype = get_kv_cache_torch_dtype( - self.runner.kv_cache_dtype, self.runner.model_config.dtype) - - return FlashInferMetadata( - decode_query_len=decode_query_len, - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - multi_modal_placeholder_index_maps=placeholder_index_maps, - enable_kv_scales_calculation=False, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - block_tables=block_tables, - paged_kv_indptr=paged_kv_indptr_tensor, - paged_kv_indices=paged_kv_indices_tensor, - paged_kv_last_page_len=paged_kv_last_page_len_tensor, - block_table_bound=block_table_bound_tensor, - seq_lens_tensor=seq_lens_tensor, - num_qo_heads=self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config), - num_kv_heads=self.runner.model_config.get_num_kv_heads( - self.runner.parallel_config), - head_dim=self.runner.model_config.get_head_size(), - page_size=self.block_size, - seq_start_loc=seq_start_loc, - query_start_loc=query_start_loc, - device=device, - data_type=kv_cache_dtype, - q_data_type=self.runner.model_config.dtype, - use_cuda_graph=use_captured_graph, - is_profile_run=self.is_profile_run, - window_left=self.window_left, - logits_soft_cap=self.logits_soft_cap, - sm_scale=self.sm_scale, - ) - - -class FlashInferImpl(AttentionImpl): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = 
None, - use_irope: bool = False, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0 " - "FLASHINFER backend.") - if use_irope: - logger.warning_once( - "Using irope in FlashInfer is not supported yet, it will fall" - " back to global attention for long context.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = ((sliding_window - 1, - 0) if sliding_window is not None else (-1, -1)) - self.kv_cache_dtype = kv_cache_dtype - self.logits_soft_cap = logits_soft_cap - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashInferImpl") - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: FlashInferMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - - if output_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashInferImpl") - - # TODO: directly write to output tensor - num_heads: int = self.num_heads - head_size: int = self.head_size - num_kv_heads: int = self.num_kv_heads - kv_cache_dtype: str = self.kv_cache_dtype - softmax_scale: float = self.scale - window_size = self.sliding_window - alibi_slopes = self.alibi_slopes - logits_soft_cap = self.logits_soft_cap - - num_tokens, hidden_size = query.shape - query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - - if kv_cache.numel() > 
0: - # Use the same reshape and cache kernel as flash attention. - ops.reshape_and_cache_flash( - key, - value, - kv_cache[:, 0], - kv_cache[:, 1], - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 - # to process the cache when the kv_cache_dtype is fp8 - if kv_cache_dtype.startswith("fp8"): - torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - kv_cache_dtype) - kv_cache = kv_cache.view(torch_dtype) - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - query = query.contiguous( - ) # Flashinfer requires query to be contiguous - # Query for decode. KV is not needed because it is already cached. - # QKV for prefill. - decode_query = query[num_prefill_tokens:] - query = query[:num_prefill_tokens] - - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - window_left = window_size[0] if window_size is not None else -1 - - prefill_output: Optional[torch.Tensor] = None - if num_decode_tokens > 0: - decode_output = torch.empty(decode_query.shape, - dtype=decode_query.dtype, - device=decode_query.device) - else: - decode_output = None - stride_order = FlashInferBackend.get_kv_cache_stride_order() - if prefill_meta := attn_metadata.prefill_metadata: - # We will use flash attention for prefill - # when kv_cache is not provided. - # This happens when vllm runs the profiling to - # determine the number of blocks. 
- if kv_cache.numel() == 0: - prefill_output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - ) - else: - assert prefill_meta is not None - assert prefill_meta.prefill_wrapper is not None - - assert prefill_meta.prefill_wrapper._causal - assert prefill_meta.prefill_wrapper._window_left == window_left - assert prefill_meta.prefill_wrapper._logits_soft_cap == ( - logits_soft_cap or 0.0) - assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale - - prefill_output = prefill_meta.prefill_wrapper.run( - query, - kv_cache.permute(*stride_order), - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - ) - if decode_meta := attn_metadata.decode_metadata: - assert decode_meta is not None - assert decode_meta.decode_wrapper is not None - - assert decode_meta.decode_wrapper._window_left == window_left - assert decode_meta.decode_wrapper._logits_soft_cap == ( - logits_soft_cap or 0.0) - assert decode_meta.decode_wrapper._sm_scale == softmax_scale - # TODO: @pavanimajety Remove this once the switch happens - # inside flashinfer. 
- if not use_trtllm_attention( - num_decode_tokens, attn_metadata.max_decode_seq_len, - kv_cache_dtype, attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, attn_metadata.head_dim): - decode_meta.decode_wrapper.run( - decode_query, - kv_cache.permute(*stride_order), - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - out=decode_output, - ) - else: - workspace_buffer = ( - decode_meta.decode_wrapper._float_workspace_buffer) - assert FlashInferState.get_kv_cache_layout() == "HND" - trtllm_batch_decode_with_kv_cache( - query=decode_query, - kv_cache=kv_cache.permute(*stride_order), - workspace_buffer=workspace_buffer, - block_tables=attn_metadata.block_tables, - seq_lens=decode_meta.seq_lens_tensor, - max_seq_len=attn_metadata.max_decode_seq_len, - bmm1_scale=layer._k_scale_float * softmax_scale, - bmm2_scale=layer._v_scale_float, - out=decode_output, - ) - - if prefill_output is None and decode_output is not None: - # Decode only batch. - output, num_tokens = decode_output, num_decode_tokens - elif decode_output is None and prefill_output is not None: - # Prefill only batch. - output, num_tokens = prefill_output, num_prefill_tokens - else: - # Chunked prefill batch does not work with speculative decoding in - # FlashInfer backend, so the query length for decode should be 1. 
- assert prefill_output is not None - assert decode_output is not None - assert decode_meta is not None - assert decode_meta.decode_query_len == 1 - decode_output = decode_output.squeeze(1) - output = torch.cat([prefill_output, decode_output], dim=0) - return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 8ff7f56743230..c5ed4c6e40326 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -837,8 +837,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): self.context_chunk_workspace_size // num_prefills_with_context # align max_context_chunk to page_size by rounding down, - # currently the `gather_cache` kernel cannot handle - # `context_chunk_starts` that are not aligned to page_size + # currently the `gather_and_maybe_dequant_cache` kernel cannot + # handle `context_chunk_starts` that are not aligned to page_size max_context_chunk = round_down(max_context_chunk, self.page_size) assert max_context_chunk > 0 num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) @@ -1082,6 +1082,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ): prefill_metadata = attn_metadata.prefill_metadata assert prefill_metadata is not None @@ -1103,12 +1104,14 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): for i in range(iters): toks = prefill_metadata.context_chunk_seq_tot[i] - ops.gather_cache( + ops.gather_and_maybe_dequant_cache( src_cache=kv_c_and_k_pe_cache, dst=workspace, block_table=prefill_metadata.block_tables, cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], batch_size=prefill_metadata.num_prefills, + kv_cache_dtype=self.kv_cache_dtype, + scale=k_scale, seq_starts=prefill_metadata.context_chunk_starts[i], ) @@ -1165,6 +1168,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): 
k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ) -> torch.Tensor: prefill_metadata = attn_metadata.prefill_metadata @@ -1197,7 +1201,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): # ROCm flash_attn_varlen_func will return 3 objects instead of 2 suffix_output, suffix_lse = output context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata) + q, kv_c_and_k_pe_cache, attn_metadata, k_scale) output = torch.empty_like(suffix_output) merge_attn_states( @@ -1234,12 +1238,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): attn_metadata: T, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: if output is not None: raise NotImplementedError( "output is not yet supported for MLAImplBase") - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for MLAImplBase") @@ -1287,7 +1292,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): if has_prefill: output[:num_prefill_tokens] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata) + attn_metadata, layer._k_scale) if has_decode: decode_q_nope, decode_q_pe = decode_q.split( diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 63e467f5a7a22..9262144e37b54 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -20,7 +20,7 @@ from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + QuantKey, kFp8StaticTensorSym) from vllm.platforms import current_platform logger = 
init_logger(__name__) @@ -529,11 +529,9 @@ class ROCmFlashAttentionImpl(AttentionImpl): head_dim).reshape(tokens, n_kv_heads * n_rep, head_dim)) - def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, - group_shape: GroupShape): + def fused_output_quant_supported(self, quant_key: QuantKey): if self.use_triton_flash_attn: - return dtype == current_platform.fp8_dtype( - ) and static and group_shape == GroupShape.PER_TENSOR + return quant_key == kFp8StaticTensorSym # Only supported in the Triton backend return False @@ -548,6 +546,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): attn_metadata: ROCmFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -585,17 +584,18 @@ class ROCmFlashAttentionImpl(AttentionImpl): use prefill sequence attributes Args: + layer: Attention layer instance. query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + kv_cache: KV cache tensor with shape + [2, num_blocks, block_size * num_kv_heads * head_size]. NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. - attn_type: Select attention type, between encoder attention, - decoder self-attention, or encoder/decoder cross- - attention. Defaults to decoder self-attention, - which is the vLLM default generally + output: Optional output tensor. + output_scale: Optional output scale tensor. + output_block_scale: Optional output block scale tensor. 
Returns: shape = [num_tokens, num_heads * head_size] """ @@ -606,6 +606,11 @@ class ROCmFlashAttentionImpl(AttentionImpl): "fused output quantization only supported for Triton" " implementation in ROCMFlashAttentionImpl for now") + if output_block_scale is not None: + raise NotImplementedError( + "fused nvfp4 output quantization is not supported" + " for ROCMFlashAttentionImpl") + query = query.view(-1, self.num_heads, self.head_size) if key is not None: assert value is not None diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 34e059067d84d..7b6c426b0f851 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -561,7 +561,7 @@ def get_num_prefill_decode_query_kv_tokens( Raises: AssertionError: If the number of encoder tokens in `attn_metadata` - is `None` when required for the calculations. + is `None` when required for the calculations. """ num_prefill_query_tokens = 0 num_decode_query_tokens = 0 diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 0bc38b4142901..302d3d7ea903f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -432,6 +432,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): attn_metadata: "XFormersMetadata", output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -470,21 +471,22 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): max_encoder_seq_len) Args: + layer: Attention layer instance. query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + kv_cache: KV cache tensor with shape + [2, num_blocks, block_size * num_kv_heads * head_size]. 
NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. - attn_type: Select attention type, between encoder attention, - decoder self-attention, or encoder/decoder cross- - attention. Defaults to decoder self-attention, - which is the vLLM default generally + output: Optional output tensor. + output_scale: Optional output scale tensor. + output_block_scale: Optional output block scale tensor. Returns: shape = [num_tokens, num_heads * head_size] """ - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for XFormersImpl") @@ -643,7 +645,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): for API spec. Args: - output: shape = [num_prefill_tokens, num_heads, head_size] query: shape = [num_prefill_tokens, num_heads, head_size] key: shape = [num_prefill_tokens, num_kv_heads, head_size] value: shape = [num_prefill_tokens, num_kv_heads, head_size] diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 1a9c0e26b53ca..237802afccde9 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -18,6 +18,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group, is_v1_kv_transfer_group) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -54,7 +55,7 @@ def check_xformers_availability(): return USE_XFORMERS_OPS -class Attention(nn.Module): +class Attention(nn.Module, AttentionLayerBase): """Attention layer. This class takes query, key, and value tensors as input. 
The input tensors @@ -128,11 +129,17 @@ class Attention(nn.Module): self._q_scale = torch.tensor(1.0, dtype=torch.float32) self._prob_scale = torch.tensor(1.0, dtype=torch.float32) - # We also keep the float32 versions of k/v_scale for attention - # backends that don't support tensors (Flashinfer) + # We also keep q/k/v_scale on host (cpu) memory for attention + # backends that require the scales to be on host instead of on device. + # e.g. Flashinfer + self._q_scale_float = 1.0 self._k_scale_float = 1.0 self._v_scale_float = 1.0 + # The output scale on host memory. This should be the input scale of + # the quant op after this attention layer. + self._o_scale_float: Optional[float] = None + self.use_mla = use_mla self.num_heads = num_heads self.head_size = head_size @@ -183,8 +190,7 @@ class Attention(nn.Module): # torch.compile works by registering the attention as one giant # opaque custom op. For other platforms, we directly call them # and let torch.compile handle them. - self.use_direct_call = not current_platform.is_cuda_alike( - ) and not current_platform.is_cpu() + self.use_direct_call = not current_platform.opaque_attention_op() self.use_output = self.attn_backend.accept_output_buffer compilation_config = get_current_vllm_config().compilation_config @@ -291,6 +297,7 @@ class Attention(nn.Module): self._q_scale.copy_(torch.abs(query).max() / self.q_range) self._k_scale.copy_(torch.abs(key).max() / self.k_range) self._v_scale.copy_(torch.abs(value).max() / self.v_range) + self._q_scale_float = self._q_scale.item() self._k_scale_float = self._k_scale.item() self._v_scale_float = self._v_scale.item() # We only calculate the scales once @@ -308,6 +315,15 @@ class Attention(nn.Module): if hasattr(self.impl, "process_weights_after_loading"): self.impl.process_weights_after_loading(act_dtype) + # FlashInfer requires attention sinks to be float32 + if (self.backend == _Backend.FLASHINFER_VLLM_V1 + and hasattr(self.impl, 'sinks')): + from 
vllm.v1.attention.backends.flashinfer import FlashInferImpl + assert isinstance(self.impl, FlashInferImpl) + if (self.impl.sinks is not None + and self.impl.sinks.dtype != torch.float32): + self.impl.sinks = self.impl.sinks.to(torch.float32) + def get_attn_backend(self) -> type[AttentionBackend]: return self.attn_backend @@ -479,6 +495,7 @@ def unified_attention_with_output( output: torch.Tensor, layer_name: str, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> None: wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() @@ -494,7 +511,8 @@ def unified_attention_with_output( kv_cache, attn_metadata, output=output, - output_scale=output_scale) + output_scale=output_scale, + output_block_scale=output_block_scale) maybe_save_kv_layer_to_connector(layer_name, kv_cache) @@ -506,6 +524,7 @@ def unified_attention_with_output_fake( output: torch.Tensor, layer_name: str, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> None: return @@ -513,7 +532,7 @@ def unified_attention_with_output_fake( direct_register_custom_op( op_name="unified_attention_with_output", op_func=unified_attention_with_output, - mutates_args=["output"], + mutates_args=["output", "output_block_scale"], fake_impl=unified_attention_with_output_fake, dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 892077ba91e07..087c5004bde06 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -6,12 +6,13 @@ from typing import List, Optional import torch from vllm import envs -from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) from vllm.attention.selector import get_attn_backend from vllm.config 
import CacheConfig, QuantizationConfig from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, make_local_attention_virtual_batches, - subclass_attention_backend, subclass_attention_metadata_builder) + subclass_attention_backend) from ..layer import Attention @@ -24,21 +25,23 @@ def create_chunked_local_attention_backend( ) -> type[AttentionBackend]: prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_" - def build_preprocess_fn(cm: CommonAttentionMetadata): - return make_local_attention_virtual_batches(attention_chunk_size, cm, - block_size) + underlying_builder = underlying_attn_backend.get_builder_cls() + + class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> AttentionMetadata: + common_attn_metadata = make_local_attention_virtual_batches( + attention_chunk_size, common_attn_metadata, block_size) + return super().build(common_prefix_len, common_attn_metadata, + fast_build) - # Dynamically create a new attention backend that wraps the - # underlying attention backend but applies - # `make_local_attention_virtual_batches` before calling `build(...)` - builder_cls = subclass_attention_metadata_builder( - name_prefix=prefix, - builder_cls=underlying_attn_backend.get_builder_cls(), - build_preprocess_fn=build_preprocess_fn) attn_backend = subclass_attention_backend( name_prefix=prefix, attention_backend_cls=underlying_attn_backend, - builder_cls=builder_cls) + builder_cls=ChunkedLocalAttentionBuilder) return attn_backend diff --git a/vllm/attention/layers/encoder_only_attention.py b/vllm/attention/layers/encoder_only_attention.py new file mode 100644 index 0000000000000..cea05df5b96d2 --- /dev/null +++ b/vllm/attention/layers/encoder_only_attention.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools 
+from copy import copy +from typing import Optional + +import torch + +from vllm import envs +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, AttentionType) +from vllm.attention.layer import Attention +from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, + subclass_attention_backend) + + +@functools.lru_cache +def create_encoder_only_attention_backend( + underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]: + prefix = "EncoderOnlyAttention_" + underlying_builder = underlying_attn_backend.get_builder_cls() + + class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> AttentionMetadata: + new_common_attn_metadata = copy(common_attn_metadata) + new_common_attn_metadata.causal = False + return super().build(common_prefix_len, new_common_attn_metadata, + fast_build) + + attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=EncoderOnlyAttentionBuilder) + + return attn_backend + + +class EncoderOnlyAttention(Attention): + """ + Encoder attention is a special case that doesn't need a KV Cache. 
+ """ + + def __init__(self, + num_heads: int, + head_size: int, + scale: float, + cache_config: Optional[CacheConfig] = None, + attn_type: Optional[str] = None, + **kwargs): + dtype = torch.get_default_dtype() + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + + if envs.VLLM_USE_V1: + underlying_attn_backend = get_attn_backend(head_size, dtype, + kv_cache_dtype, + block_size) + + attn_backend = create_encoder_only_attention_backend( + underlying_attn_backend) + else: + # in v0 encoder only attention is handled inside the backends + attn_backend = None + + if attn_type is not None: + assert attn_type == AttentionType.ENCODER_ONLY, \ + "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY" + + super().__init__(num_heads=num_heads, + head_size=head_size, + scale=scale, + cache_config=cache_config, + attn_backend=attn_backend, + attn_type=AttentionType.ENCODER_ONLY, + **kwargs) diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 1af26dfc3daa3..564042cf8eb12 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -67,6 +67,8 @@ def flash_mla_with_kvcache( num_splits: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: @@ -81,6 +83,8 @@ def flash_mla_with_kvcache( softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. + descale_q: (batch_size), torch.float32. Descaling factors for Q. + descale_k: (batch_size), torch.float32. Descaling factors for K. Return: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). 
@@ -98,6 +102,8 @@ def flash_mla_with_kvcache( causal, tile_scheduler_metadata, num_splits, + descale_q, + descale_k, ) return out, softmax_lse diff --git a/vllm/beam_search.py b/vllm/beam_search.py index f3bc4218323d8..5a2e79e1b5c74 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -18,7 +18,7 @@ class BeamSearchSequence: The text field is optional and will only be filled when the sequence is about to be returned to the user. """ - # The tokens includes the prompt. + # The tokens include the prompt. tokens: list[int] logprobs: list[dict[int, Logprob]] lora_request: Optional[LoRARequest] = None diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 4e8ac5162542f..93519b5ba1523 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -11,21 +11,26 @@ generation. Supported dataset types include: - HuggingFace - VisionArena """ +import ast import base64 import io import json import logging +import math import random from abc import ABC, abstractmethod -from collections.abc import Mapping +from collections.abc import Iterator, Mapping +from contextlib import suppress +from copy import deepcopy from dataclasses import dataclass from functools import cache from io import BytesIO -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, cast import numpy as np from PIL import Image from transformers import PreTrainedTokenizerBase +from typing_extensions import deprecated from vllm.lora.request import LoRARequest from vllm.lora.utils import get_adapter_absolute_path @@ -68,13 +73,14 @@ class SampleRequest: Represents a single inference request for benchmarking. 
""" - prompt: Union[str, Any] + prompt: Union[str, list[str]] prompt_len: int expected_output_len: int multi_modal_data: Optional[ Union[MultiModalDataDict, dict, list[dict]] ] = None lora_request: Optional[LoRARequest] = None + request_id: Optional[str] = None # ----------------------------------------------------------------------------- @@ -111,7 +117,9 @@ class BenchmarkDataset(ABC): def apply_multimodal_chat_transformation( self, prompt: str, - mm_content: Optional[MultiModalDataDict] = None) -> list[dict]: + mm_content: Optional[ + Union[MultiModalDataDict, dict, list[dict]] + ] = None) -> list[dict]: """ Transform a prompt and optional multimodal content into a chat format. This method is used for chat models that expect a specific conversation @@ -119,7 +127,15 @@ class BenchmarkDataset(ABC): """ content = [{"text": prompt, "type": "text"}] if mm_content is not None: - content.append(mm_content) + if isinstance(mm_content, list): + content.extend(cast(list[dict[str, Any]], mm_content)) + elif isinstance(mm_content, dict): + content.append(mm_content) + else: + raise TypeError( + "Could not process multimodal content of type: " + + f"{type(mm_content)}" + ) return [{"role": "user", "content": content}] def load_data(self) -> None: @@ -182,7 +198,8 @@ class BenchmarkDataset(ABC): @abstractmethod def sample(self, tokenizer: PreTrainedTokenizerBase, - num_requests: int) -> list[SampleRequest]: + num_requests: int, + request_id_prefix: str = "") -> list[SampleRequest]: """ Abstract method to generate sample requests from the dataset. @@ -193,6 +210,8 @@ class BenchmarkDataset(ABC): tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for processing the dataset's text. num_requests (int): The number of sample requests to generate. + request_id_prefix (str) The prefix of request_id. 
+ Returns: list[SampleRequest]: A list of sample requests generated from the @@ -200,8 +219,12 @@ class BenchmarkDataset(ABC): """ raise NotImplementedError("sample must be implemented in subclasses.") - def maybe_oversample_requests(self, requests: list[SampleRequest], - num_requests: int) -> None: + def maybe_oversample_requests( + self, + requests: list[SampleRequest], + num_requests: int, + request_id_prefix: str = "", + ) -> None: """ Oversamples the list of requests if its size is less than the desired number. @@ -210,11 +233,17 @@ class BenchmarkDataset(ABC): requests (List[SampleRequest]): The current list of sampled requests. num_requests (int): The target number of requests. + request_id_prefix (str) The prefix of the request ids. + """ if len(requests) < num_requests: random.seed(self.random_seed) - additional = random.choices(requests, - k=num_requests - len(requests)) + additional = deepcopy( + random.choices(requests, k=num_requests - len(requests)) + ) + for i in range(len(additional)): + req = additional[i] + req.request_id = request_id_prefix + str(len(requests) + i) requests.extend(additional) logger.info("Oversampled requests to reach %d total samples.", num_requests) @@ -265,7 +294,7 @@ def process_image(image: Any) -> Mapping[str, Any]: """ Process a single image input and return a multimedia content dictionary. - Supports three input types: + Supports the following input types: 1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key containing raw image data. - Loads the bytes as a PIL.Image.Image. @@ -305,94 +334,592 @@ def process_image(image: Any) -> Mapping[str, Any]: " or str or dictionary with raw image bytes.") +def process_video(video: Any) -> Mapping[str, Any]: + """ + Process a single video input and return a multimedia content dictionary. + + Supports the following input types: + + 1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key + containing raw video data. + + 2. 
String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the video URL. + + Raises: + ValueError: If the input is not a supported type. + """ + if isinstance(video, dict) and 'bytes' in video: + video_bytes = video['bytes'] + video_base64 = base64.b64encode(video_bytes).decode("utf-8") + return { + "type": "video_url", + "video_url": { + "url": f"data:video/mp4;base64,{video_base64}" + }, + } + + if isinstance(video, str): + video_url = (video if video.startswith( + ("http://", "file://")) else f"file://{video}") + return {"type": "video_url", "video_url": {"url": video_url}} + + raise ValueError( + f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501 + ) + # ----------------------------------------------------------------------------- # Random Dataset Implementation (Synthetic Data) # ----------------------------------------------------------------------------- class RandomDataset(BenchmarkDataset): + """ + Synthetic text-only dataset for serving/throughput benchmarks. + + Strategy: + - Sample input/output token lengths per request from integer-uniform ranges + around configured means (controlled by range_ratio). + - Prepend a fixed random prefix of length prefix_len. + - Generate the remaining tokens as a reproducible sequence: + (offset + index + arange(input_len)) % vocab_size. + - Decode then re-encode/truncate to ensure prompt token counts match. + - Uses numpy.default_rng seeded with random_seed for reproducible sampling. + """ # Default values copied from benchmark_serving.py for the random dataset.
DEFAULT_PREFIX_LEN = 0 DEFAULT_RANGE_RATIO = 0.0 DEFAULT_INPUT_LEN = 1024 DEFAULT_OUTPUT_LEN = 128 - def __init__( - self, - **kwargs, - ) -> None: + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - random.seed(self.random_seed) - np.random.seed(self.random_seed) + # Use numpy's default_rng for deterministic sampling + # Do not use random.seed() or np.random.seed() elsewhere in this class. + # This ensures that the RNG is isolated from global RNG state. + self._rng = np.random.default_rng(self.random_seed) def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, + request_id_prefix: str = "", prefix_len: int = DEFAULT_PREFIX_LEN, range_ratio: float = DEFAULT_RANGE_RATIO, input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, + batchsize: int = 1, **kwargs, ) -> list[SampleRequest]: - # Enforce range_ratio < 1 - assert range_ratio < 1.0, ( - "random_range_ratio must be < 1.0 to ensure a valid sampling range" + + input_lens, output_lens, offsets = self.get_sampling_params( + num_requests, range_ratio, input_len, output_len, tokenizer ) + # Generate prefix once + prefix_token_ids = self.get_prefix(tokenizer, prefix_len) vocab_size = tokenizer.vocab_size - num_special_tokens = tokenizer.num_special_tokens_to_add() - real_input_len = input_len - num_special_tokens - - prefix_token_ids = (np.random.randint( - 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) - - # New sampling logic: [X * (1 - b), X * (1 + b)] - input_low = int(real_input_len * (1 - range_ratio)) - input_high = int(real_input_len * (1 + range_ratio)) - output_low = int(output_len * (1 - range_ratio)) - output_high = int(output_len * (1 + range_ratio)) - - # Add logging for debugging - logger.info( - "Sampling input_len from [%s, %s] and output_len from [%s, %s]", - input_low, input_high, output_low, output_high) - - input_lens = np.random.randint(input_low, - input_high + 1, - size=num_requests) - output_lens = 
np.random.randint(output_low, - output_high + 1, - size=num_requests) - offsets = np.random.randint(0, vocab_size, size=num_requests) requests = [] for i in range(num_requests): - inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % - vocab_size).tolist() - token_sequence = prefix_token_ids + inner_seq - prompt = tokenizer.decode(token_sequence) - # After decoding the prompt we have to encode and decode it again. - # This is done because in some cases N consecutive tokens - # give a string tokenized into != N number of tokens. - # For example for GPT2Tokenizer: - # [6880, 6881] -> ['Ġcalls', 'here'] -> - # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] - # To avoid uncontrolled change of the prompt length, - # the encoded sequence is truncated before being decode again. - total_input_len = prefix_len + int(input_lens[i]) - re_encoded_sequence = tokenizer.encode( - prompt, add_special_tokens=False)[:total_input_len] - prompt = tokenizer.decode(re_encoded_sequence) - total_input_len = len(re_encoded_sequence) + prompt, total_input_len = self.generate_token_sequence( + tokenizer=tokenizer, + prefix_token_ids=prefix_token_ids, + prefix_len=prefix_len, + vocab_size=vocab_size, + input_len=int(input_lens[i]), + offset=int(offsets[i]), + index=i, + ) requests.append( SampleRequest( prompt=prompt, prompt_len=total_input_len, expected_output_len=int(output_lens[i]), - )) + request_id=request_id_prefix + str(i), + ) + ) + # only used for embeddings benchmark. 
+ if batchsize > 1: + batch_requests = [] + # Create batched requests + for i in range(0, num_requests, batchsize): + batch = requests[i : i + batchsize] + batch_requests.append( + SampleRequest( + prompt=[req.prompt for req in batch], + prompt_len=sum(req.prompt_len for req in batch), + expected_output_len=0, + request_id=request_id_prefix + str(i // batchsize), + ) + ) + requests = batch_requests return requests + def get_prefix( + self, tokenizer: PreTrainedTokenizerBase, prefix_len: int + ) -> list[int]: + """ + Get the prefix for the dataset. + """ + return ( + self._rng.integers( + 0, tokenizer.vocab_size, size=prefix_len).tolist() + if prefix_len > 0 + else [] + ) + + def get_sampling_params( + self, + num_requests: int, + range_ratio: float, + input_len: int, + output_len: int, + tokenizer: PreTrainedTokenizerBase, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Get the sampling parameters for the dataset. + """ + # Enforce range_ratio < 1 + if not (0.0 <= range_ratio < 1.0): + raise ValueError("range_ratio must be in [0, 1).") + num_special_tokens = int(tokenizer.num_special_tokens_to_add()) + real_input_len = max(0, int(input_len) - num_special_tokens) + # Bounds use floor for low and ceil for high + input_low = math.floor(real_input_len * (1 - range_ratio)) + input_high = math.ceil(real_input_len * (1 + range_ratio)) + output_low = math.floor(output_len * (1 - range_ratio)) + output_high = math.ceil(output_len * (1 + range_ratio)) + # Ensure the lower bound for output length is at least 1 to + # prevent sampling 0 tokens. 
+ output_low = max(output_low, 1) + + if input_low > input_high: + raise ValueError( + "Invalid input sampling interval: " + f"low={input_low} > high={input_high}" + ) + if output_low > output_high: + raise ValueError( + "Invalid output sampling interval: " + f"low={output_low} > high={output_high}" + ) + + logger.info( + "Sampling input_len from [%s, %s] and output_len from [%s, %s]", + input_low, + input_high, + output_low, + output_high, + ) + + input_lens = self._rng.integers(input_low, input_high + 1, + size=num_requests) + output_lens = self._rng.integers(output_low, output_high + 1, + size=num_requests) + offsets = self._rng.integers(0, tokenizer.vocab_size, + size=num_requests) + return input_lens, output_lens, offsets + + def generate_token_sequence( + self, + *, + tokenizer: PreTrainedTokenizerBase, + prefix_token_ids: list[int], + prefix_len: int, + vocab_size: int, + input_len: int, + offset: int, + index: int, + ) -> tuple[str, int]: + """ + Returns (prompt, total_input_len). + + NOTE: After decoding the prompt we have to encode and decode it again. + This is done because in some cases N consecutive tokens + give a string tokenized into != N number of tokens. + For example for GPT2Tokenizer: + [6880, 6881] -> ['Ġcalls', 'here'] -> + [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] + To avoid uncontrolled change of the prompt length, + the encoded sequence is truncated before being decoded again.
+ """ + # Build the inner sequence by sampling sequentially from the vocab + inner_seq = ((offset + index + np.arange(input_len)) + % vocab_size).tolist() + token_sequence = prefix_token_ids + inner_seq + + # Decode, then re-encode and truncate to preserve token count invariants + prompt = tokenizer.decode(token_sequence) + total_input_len = prefix_len + int(input_len) + + re_encoded_sequence = tokenizer.encode( + prompt, add_special_tokens=False)[:total_input_len] + prompt = tokenizer.decode(re_encoded_sequence) + total_input_len = len(re_encoded_sequence) + + return prompt, total_input_len + + +# ----------------------------------------------------------------------------- +# MultiModalDataset Implementation +# ----------------------------------------------------------------------------- + +class RandomMultiModalDataset(RandomDataset): + """ + Synthetic multimodal dataset (text + images) that extends RandomDataset. + + Status: + - Images: supported via synthetic RGB data. + - Video: not yet supported (TODO: implement video generation method). + - Audio: not yet supported. + + Sampling overview: + 1) Number of items per request is sampled uniformly from the integer range + [floor(n·(1−r)), ceil(n·(1+r))], where n is the base count and r is + `num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0. + The maximum is further clamped to the sum of per-modality limits. + 2) Each item’s modality and shape is sampled from `bucket_config`, a dict + mapping (height, width, num_frames) → probability. We treat + `num_frames`=1 as image and `num_frames` > 1 as video. + Entries with zero probability are removed and the rest are renormalized + to sum to 1. + 3) Per-modality hard caps are enforced via `limit_mm_per_prompt`. + When a modality reaches its cap, all of its buckets are excluded and the + remaining probabilities are renormalized.
+ + Example bucket configuration: + {(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1} + - Two image buckets (`num_frames`=1) and one video bucket + (`num_frames`=16). + OBS.: Only image sampling is supported for now. + """ + + IS_MULTIMODAL = True + # NOTE: video sampling is WIP. Setting it to 0. + DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 0} + + DEFAULT_BASE_ITEMS_PER_REQUEST = 1 + DEFAULT_NUM_MM_ITEMS_RANGE_RATIO = 0.0 + DEFAULT_MM_ITEM_BUCKET_CONFIG = { + (256, 256, 1): 0.5, + (720, 1280, 1): 0.5, + (720, 1280, 16): 0.0, + } + DEFAULT_ENABLE_MULTIMODAL_CHAT = False + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + + def generate_synthetic_image(self, width: int, height: int) -> Image.Image: + """Generate synthetic PIL image with random RGB values. + + NOTE: iid pixel sampling results in worst-case compression + (good for stressing I/O), but very unlike real photos. + We could consider a “low-freq” mode (e.g., noise blur) + to emulate network realism instead of max stress. + """ + random_pixels = self._rng.integers( + 0, + 256, + (height, width, 3), + dtype=np.uint8, + ) + return Image.fromarray(random_pixels) + + def generate_synthetic_video(self, width: int, + height: int, + num_frames: int) -> Any: + """Generate synthetic video with random values. + + TODO: Finish this method. + """ + raise NotImplementedError("Video sampling is WIP.") + + def map_config_to_modality(self, config: tuple[int, int, int]) -> str: + """Map the configuration to the modality.""" + if config[-1] == 1: + return "image" + elif config[-1] > 1: + return "video" + else: + raise ValueError(f"Invalid multimodal item configuration: {config}") + + def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int], + float]) -> dict[tuple[int, int, int], float]: + """ + Remove zero probability entries + and normalize the bucket config to sum to 1. 
+ """ + # Raise error if value is negative + if any(v < 0 for v in bucket_config.values()): + raise ValueError("Bucket config values must be non-negative.") + # Remove zero probability entries + bucket_config = {k: v for k, v in bucket_config.items() if v > 0} + # if bucket config is empty, raise error + if not bucket_config: + raise ValueError("Got invalid bucket config. " + "Bucket config values must be non-zero.") + # Normalize the remaining bucket config to sum to 1 + total = sum(bucket_config.values()) + return {k: v / total for k, v in bucket_config.items()} + + + def generate_mm_item(self, + mm_item_config: tuple[int, int, int], + ) -> Mapping[str, Any]: + """ + Create synthetic images and videos and + apply process_image/process_video respectively. + This follows the OpenAI API chat completions + https://github.com/openai/openai-python + """ + + if self.map_config_to_modality(mm_item_config) == "image": + return process_image(self.generate_synthetic_image( + mm_item_config[1], + mm_item_config[0])) + elif self.map_config_to_modality(mm_item_config) == "video": + return process_video(self.generate_synthetic_video( + mm_item_config[1], + mm_item_config[0], + mm_item_config[2])) + else: + raise ValueError(f"Invalid multimodal item configuration: " + f"{mm_item_config}") + + + def get_mm_item_sampling_params( + self, + base_items_per_request: int, + num_mm_items_range_ratio: float, + limit_mm_per_prompt: dict[str, int], + bucket_config: dict[tuple[int, int, int], float], + ) -> tuple[int, int, dict[str, int], dict[tuple[int, int, int], float]]: + """ + Get the sampling parameters for the multimodal items. 
+ """ + # Enforce num_mm_items_range_ratio <= 1 + if not (0.0 <= num_mm_items_range_ratio <= 1.0): + raise ValueError("num_mm_items_range_ratio must be in [0, 1].") + + # Ensure modalities to sample are in limit_mm_per_prompt + for k, v in bucket_config.items(): + # get modality from bucket config + modality = self.map_config_to_modality(k) + if modality not in limit_mm_per_prompt: + raise ValueError(f"Modality {modality} is not in " + f"limit_mm_per_prompt: " + f"{limit_mm_per_prompt.keys()}") + + # Remove zero probability entries + # and normalize bucket config to sum to 1 + bucket_config = self.normalize_bucket_config(bucket_config) + logger.info( + "Normalized bucket config: %s", bucket_config, + ) + # Only consider limit per prompt for modalities in bucket config + allowed_modalities = {self.map_config_to_modality(cfg) + for cfg in bucket_config} + limit_mm_per_prompt = { + k: v for k, v in limit_mm_per_prompt.items() + if k in allowed_modalities} + if not limit_mm_per_prompt: + raise ValueError("No valid limits for modalities present in " + "bucket_config.") + + logger.info( + "Updated mm-limit-per-prompt: %s", limit_mm_per_prompt, + ) + + # Get max and min num mm items and ensure + # it is at most the sum of limit_mm_per_prompt for all modalities + max_num_mm_items = min( + sum(limit_mm_per_prompt.values()), + math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio)) + ) + # Ensure min num mm items is at least 0 + min_num_mm_items = max( + 0, + math.floor(base_items_per_request * (1 - num_mm_items_range_ratio)) + ) + # Raise error if min num mm items is greater than max num mm items + if min_num_mm_items > max_num_mm_items: + raise ValueError(f"Min num mm items is greater than max mm items: " + f"{min_num_mm_items} > {max_num_mm_items}") + + logger.info( + "Sampling number of multimodal items from [%s, %s]", + min_num_mm_items, max_num_mm_items, + ) + + return ( + min_num_mm_items, + max_num_mm_items, + limit_mm_per_prompt, + bucket_config, + ) + + 
def get_mm_item_iterator( + self, + min_num_mm_items: int, + max_num_mm_items: int, + bucket_config: dict[tuple[int, int, int], float], + limit_mm_per_prompt: dict[str, int], + ) -> Iterator[tuple[int,int, int]]: + """ + Iterator over the multimodal items for each request + whose size is between min_num_mm_items and max_num_mm_items. + + Loop over the bucket config and sample a multimodal item. + Loop until the number of multimodal items sampled is equal to + request_num_mm_items or limit of multimodal items per prompt + for all modalities is reached. + + Note: + - This function operates on a per-request shallow copy of + `bucket_config` (tuple->float). The original dict passed to + `sample` is not mutated. If this ever changes, a test + is implemented and will fail. + """ + # Get the number of multimodal items to sample + request_num_mm_items = int( + self._rng.integers(min_num_mm_items, max_num_mm_items + 1) + ) + # If request_num_mm_items is 0, yield an empty iterator + if request_num_mm_items == 0: + return + # Initialize modality counters + modality_counter = {self.map_config_to_modality(k): 0 + for k in bucket_config} + # Copy the bucket config to avoid modifying the original + bucket_config_copy = bucket_config.copy() + # Loop over the number of multimodal items to sample + while sum(modality_counter.values()) < request_num_mm_items: + # Sample a multimodal item config + mm_item_config = self._rng.choice(list(bucket_config_copy.keys()), + p=list(bucket_config_copy.values())) + modality = self.map_config_to_modality(mm_item_config) + # Check that modality count is less than limit per prompt + if modality_counter[modality] < limit_mm_per_prompt[modality]: + modality_counter[modality] += 1 + yield ( + mm_item_config + ) + else: + # If the counter is greater than the limit per prompt + # set all multimodal items of this modality to 0 + for k, v in bucket_config_copy.items(): + if self.map_config_to_modality(k) == modality: + bucket_config_copy[k] = 0 + # If all 
configs are 0, break the loop + # This should not happen as request_num_mm_items is at most + # the sum of limit_mm_per_prompt for all modalities + if all(v == 0 for v in bucket_config_copy.values()): + logger.warning("Exhausted all multimodal items " + "of modality %s", + modality) + break + # Renormalize the bucket config + bucket_config_copy = self.normalize_bucket_config( + bucket_config_copy) + + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", + prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN, + range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO, + input_len: int = RandomDataset.DEFAULT_INPUT_LEN, + output_len: int = RandomDataset.DEFAULT_OUTPUT_LEN, + limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT, + base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST, + num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, + bucket_config: dict[tuple[int, int, int], float] = + DEFAULT_MM_ITEM_BUCKET_CONFIG, + enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT, + **kwargs, + ) -> list[SampleRequest]: + + # NOTE: Video sampling is WIP. Raise error if video is in bucket config + # and probability is non-zero. 
+ if any(self.map_config_to_modality(cfg) == "video" and p > 0 + for cfg, p in bucket_config.items()): + raise NotImplementedError("Video sampling not implemented; " + "set its probability to 0.") + + # Get the sampling parameters for the dataset + input_lens, output_lens, offsets = self.get_sampling_params( + num_requests, range_ratio, input_len, output_len, tokenizer + ) + + ( + min_num_mm_items, + max_num_mm_items, + limit_mm_per_prompt, + bucket_config, + ) = self.get_mm_item_sampling_params( + base_items_per_request, + num_mm_items_range_ratio, + limit_mm_per_prompt, + bucket_config, + ) + + # Generate prefix once + prefix_token_ids = self.get_prefix(tokenizer, prefix_len) + vocab_size = tokenizer.vocab_size + # Add synthetic multimodal items to each request + mm_requests = [] + for i in range(num_requests): + prompt, total_input_len = self.generate_token_sequence( + tokenizer=tokenizer, + prefix_token_ids=prefix_token_ids, + prefix_len=prefix_len, + vocab_size=vocab_size, + input_len=int(input_lens[i]), + offset=int(offsets[i]), + index=i, + ) + # Get multimodal item iterator for a given request + mm_item_iterator = self.get_mm_item_iterator( + min_num_mm_items, + max_num_mm_items, + bucket_config, + limit_mm_per_prompt, + ) + + mm_content = cast(list[dict[str, Any]], [ + self.generate_mm_item(mm_item_config) + for mm_item_config in mm_item_iterator + ]) + + if enable_multimodal_chat: + # NOTE: For now this option is only provided for completeness + # given that the serve.py benchmark currently does not use it. 
+ mm_chat_prompt: Any = prompt + mm_chat_prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sample_request = SampleRequest( + prompt=mm_chat_prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + multi_modal_data=None, + request_id=request_id_prefix + str(i), + ) + else: + sample_request = SampleRequest( + prompt=prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + multi_modal_data=mm_content, + request_id=request_id_prefix + str(i), + ) + mm_requests.append(sample_request) + return mm_requests # ----------------------------------------------------------------------------- # ShareGPT Dataset Implementation @@ -431,9 +958,11 @@ class ShareGPTDataset(BenchmarkDataset): max_loras: Optional[int] = None, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: samples: list = [] + ind = 0 for entry in self.data: if len(samples) >= num_requests: break @@ -454,17 +983,26 @@ class ShareGPTDataset(BenchmarkDataset): skip_min_output_len_check=output_len is not None): continue + if image_path := entry.get("image"): + mm_content = process_image(image_path) + elif video_path := entry.get("video"): + mm_content = process_video(video_path) + else: + mm_content = None if enable_multimodal_chat: prompt = self.apply_multimodal_chat_transformation( - prompt, None) + prompt, mm_content) samples.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, expected_output_len=new_output_len, lora_request=lora_request, + multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), )) - self.maybe_oversample_requests(samples, num_requests) + ind += 1 + self.maybe_oversample_requests(samples, num_requests, request_id_prefix) return samples @@ -480,7 +1018,10 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "--dataset-name", type=str, default="random", - choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"], 
+ choices=[ + "sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf", + "custom", "prefix_repetition" + ], help="Name of the dataset to benchmark on.", ) parser.add_argument( @@ -579,6 +1120,103 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "context length sampled from [input_len * (1 - range_ratio), " "input_len * (1 + range_ratio)]."), ) + random_group.add_argument( + "--random-batch-size", + type=int, + default=1, + help=("Batch size for random sampling. " + "Only used for embeddings benchmark."), + ) + + # random multimodal dataset options + random_mm_group = parser.add_argument_group( + "random multimodal dataset options extended from random dataset") + random_mm_group.add_argument( + "--random-mm-base-items-per-request", + type=int, + default=RandomMultiModalDataset.DEFAULT_BASE_ITEMS_PER_REQUEST, + help=( + "Base number of multimodal items per request for random-mm. " + "Actual per-request count is sampled around this base using " + "--random-mm-num-mm-items-range-ratio." + ), + ) + random_mm_group.add_argument( + "--random-mm-num-mm-items-range-ratio", + type=float, + default=RandomMultiModalDataset.DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, + help=( + "Range ratio r in [0, 1] for sampling items per request. " + "We sample uniformly from the closed integer range " + "[floor(n*(1-r)), ceil(n*(1+r))] " + "where n is the base items per request. " + "r=0 keeps it fixed; r=1 allows 0 items. The maximum is clamped " + "to the sum of per-modality limits from " + "--random-mm-limit-mm-per-prompt. " + "An error is raised if the computed min exceeds the max." + ), + ) + random_mm_group.add_argument( + "--random-mm-limit-mm-per-prompt", + type=json.loads, + default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT, + help=( + "Per-modality hard caps for items attached per request, e.g. " + "'{\"image\": 3, \"video\": 0}'. The sampled per-request item " + "count is clamped to the sum of these limits. 
When a modality " + "reaches its cap, its buckets are excluded and probabilities are " + "renormalized." + "OBS.: Only image sampling is supported for now." + ), + ) + + def _parse_mm_bucket_config(v: object) -> dict[tuple[int, int, int], float]: + # If already a dict (e.g., programmatic call), normalize keys + def normalize(d: dict) -> dict[tuple[int, int, int], float]: + out: dict[tuple[int, int, int], float] = {} + for k, val in d.items(): + key = k + if isinstance(key, str): + with suppress(Exception): + key = ast.literal_eval(key) + if not (isinstance(key, tuple) and len(key) == 3 + and all(isinstance(x, int) for x in key)): + raise ValueError( + f"Invalid bucket key {k!r}. Expected tuple (H, W, T)." + ) + out[(int(key[0]), int(key[1]), int(key[2]))] = float(val) + return out + + if isinstance(v, dict): + return normalize(v) + if isinstance(v, str): + # Python literal (supports tuple keys) + parsed = ast.literal_eval(v) + if not isinstance(parsed, dict): + raise ValueError("Bucket config must parse to a dict.") + return normalize(parsed) + raise ValueError("Unsupported value for --random-mm-bucket-config.") + + random_mm_group.add_argument( + "--random-mm-bucket-config", + type=_parse_mm_bucket_config, + default=RandomMultiModalDataset.DEFAULT_MM_ITEM_BUCKET_CONFIG, + help=( + "The bucket config is a dictionary mapping a multimodal item" + "sampling configuration to a probability." + "Currently allows for 2 modalities: images and videos. " + "An bucket key is a tuple of (height, width, num_frames)" + "The value is the probability of sampling that specific item. " + "Example: " + "--random-mm-bucket-config " + "{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.10} " + "First item: images with resolution 256x256 w.p. 0.5" + "Second item: images with resolution 720x1280 w.p. 0.4 " + "Third item: videos with resolution 720x1280 and 16 frames w.p. 0.1" + "OBS.: If the probabilities do not sum to 1, they are normalized." 
+ "OBS bis.: Only image sampling is supported for now." + ), + ) hf_group = parser.add_argument_group("hf dataset options") hf_group.add_argument("--hf-subset", @@ -597,6 +1235,37 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "from the sampled HF dataset.", ) + prefix_repetition_group = parser.add_argument_group( + "prefix repetition dataset options") + prefix_repetition_group.add_argument( + "--prefix-repetition-prefix-len", + type=int, + default=256, + help="Number of prefix tokens per request, used only for prefix " + "repetition dataset.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-suffix-len", + type=int, + default=256, + help="Number of suffix tokens per request, used only for prefix " + "repetition dataset. Total input length is prefix_len + suffix_len.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-num-prefixes", + type=int, + default=10, + help="Number of prefixes to generate, used only for prefix repetition " + "dataset. 
Prompts per prefix is num_requests // num_prefixes.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-output-len", + type=int, + default=128, + help="Number of output tokens per request, used only for prefix " + "repetition dataset.", + ) + def get_samples(args, tokenizer) -> list[SampleRequest]: if args.dataset_name == "custom": @@ -606,6 +1275,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: tokenizer=tokenizer, output_len=args.custom_output_len, skip_chat_template=args.custom_skip_chat_template, + request_id_prefix=args.request_id_prefix, ) elif args.dataset_name == "sonnet": @@ -619,6 +1289,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, return_prompt_formatted=False, + request_id_prefix=args.request_id_prefix, ) else: assert tokenizer.chat_template or tokenizer.default_chat_template, ( @@ -630,6 +1301,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, return_prompt_formatted=True, + request_id_prefix=args.request_id_prefix, ) elif args.dataset_name == "hf": @@ -675,10 +1347,11 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: "openai-chat", "openai-audio", ]: - # multi-modal benchmark is only available on OpenAI Chat backend. + # multi-modal benchmark is only available on OpenAI Chat + # endpoint-type. raise ValueError( "Multi-modal content is only supported on 'openai-chat' and " - "'openai-audio' backend.") + "'openai-audio' endpoint-type.") input_requests = dataset_class( dataset_path=args.dataset_path, dataset_subset=args.hf_subset, @@ -689,35 +1362,77 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: num_requests=args.num_prompts, tokenizer=tokenizer, output_len=args.hf_output_len, + request_id_prefix=args.request_id_prefix, ) else: # For datasets that follow a similar structure, use a mapping. 
dataset_mapping = { - "sharegpt": - lambda: ShareGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - output_len=args.sharegpt_output_len, - ), - "burstgpt": - lambda: BurstGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path). - sample(tokenizer=tokenizer, num_requests=args.num_prompts), - "random": - lambda: RandomDataset(random_seed=args.seed, - dataset_path=args.dataset_path).sample( + "sharegpt": lambda: ShareGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + request_id_prefix=args.request_id_prefix, + ), + "burstgpt": lambda: BurstGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + request_id_prefix=args.request_id_prefix, + ), + "random": lambda: RandomDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( tokenizer=tokenizer, num_requests=args.num_prompts, prefix_len=args.random_prefix_len, input_len=args.random_input_len, output_len=args.random_output_len, range_ratio=args.random_range_ratio, + request_id_prefix=args.request_id_prefix, + batchsize=args.random_batch_size, + ), + "random-mm": + lambda: RandomMultiModalDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.random_prefix_len, + range_ratio=args.random_range_ratio, + input_len=args.random_input_len, + output_len=args.random_output_len, + base_items_per_request=args.random_mm_base_items_per_request, + limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt, + num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio, + bucket_config=args.random_mm_bucket_config, + request_id_prefix=args.request_id_prefix, + ), + "prefix_repetition": + lambda: PrefixRepetitionRandomDataset( 
+ random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.prefix_repetition_prefix_len, + suffix_len=args.prefix_repetition_suffix_len, + num_prefixes=args.prefix_repetition_num_prefixes, + output_len=args.prefix_repetition_output_len, + request_id_prefix=args.request_id_prefix, ), } try: + # Enforce endpoint compatibility for multimodal datasets. + if args.dataset_name == "random-mm" and args.endpoint_type not in [ + "openai-chat"]: + raise ValueError( + "Multi-modal content (images) is only supported on " + "'openai-chat' backend." + ) input_requests = dataset_mapping[args.dataset_name]() except KeyError as err: raise ValueError(f"Unknown dataset: {args.dataset_name}") from err @@ -787,10 +1502,11 @@ class CustomDataset(BenchmarkDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, skip_chat_template: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break prompt = item["prompt"] @@ -812,8 +1528,10 @@ class CustomDataset(BenchmarkDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), )) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -822,7 +1540,9 @@ class CustomDataset(BenchmarkDataset): # Sonnet Dataset Implementation # ----------------------------------------------------------------------------- - +@deprecated( + "SonnetDataset is deprecated and will be removed in a future version.", +) class SonnetDataset(BenchmarkDataset): """ Simplified implementation of the Sonnet dataset. 
Loads poem lines from a @@ -855,6 +1575,7 @@ class SonnetDataset(BenchmarkDataset): input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, return_prompt_formatted: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: # Calculate average token length for a poem line. @@ -880,6 +1601,7 @@ class SonnetDataset(BenchmarkDataset): prefix_lines = self.data[:num_prefix_lines] samples = [] + ind = 0 while len(samples) < num_requests: extra_lines = random.choices(self.data, k=num_input_lines - num_prefix_lines) @@ -895,7 +1617,9 @@ class SonnetDataset(BenchmarkDataset): if return_prompt_formatted else prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(ind), )) + ind += 1 return samples @@ -946,6 +1670,7 @@ class BurstGPTDataset(BenchmarkDataset): num_requests: int, max_loras: Optional[int] = None, lora_path: Optional[str] = None, + request_id_prefix: str = "", **kwargs, ) -> list[SampleRequest]: samples = [] @@ -966,6 +1691,7 @@ class BurstGPTDataset(BenchmarkDataset): prompt_len=input_len, expected_output_len=output_len, lora_request=lora_req, + request_id=request_id_prefix + str(i), )) return samples @@ -1021,11 +1747,13 @@ class ConversationDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs) -> list: # Filter examples with at least 2 conversations filtered_data = self.data.filter( lambda x: len(x["conversations"]) >= 2) sampled_requests = [] + ind = 0 dynamic_output = output_len is None for item in filtered_data: @@ -1057,8 +1785,11 @@ class ConversationDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), )) - self.maybe_oversample_requests(sampled_requests, num_requests) + ind += 1 + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return 
sampled_requests @@ -1087,12 +1818,13 @@ class VisionArenaDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = (output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN) sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) @@ -1114,8 +1846,10 @@ class VisionArenaDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(i), )) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -1144,15 +1878,18 @@ class InstructCoderDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs) -> list: output_len = (output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN) sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break - prompt = f"{item['input']}\n\n{item['instruction']} Just output \ - the code, do not include any explanation." + prompt = ( + f"{item['input']}\n\n{item['instruction']} Just output " + "the code, do not include any explanation." 
+ ) # apply template prompt = tokenizer.apply_chat_template( @@ -1170,8 +1907,10 @@ class InstructCoderDataset(HuggingFaceDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), )) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -1201,13 +1940,14 @@ class MTBenchDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = (output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN) sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break prompt = item["turns"][0] @@ -1228,8 +1968,10 @@ class MTBenchDataset(HuggingFaceDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), )) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -1251,8 +1993,10 @@ class AIMODataset(HuggingFaceDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, output_len: Optional[int] = None, + request_id_prefix: str = "", **kwargs) -> list: sampled_requests = [] + ind = 0 dynamic_output = output_len is None for item in self.data: @@ -1277,8 +2021,12 @@ class AIMODataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=None, + request_id=request_id_prefix + str(ind), + )) - self.maybe_oversample_requests(sampled_requests, num_requests) + ind += 1 + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -1349,13 +2097,14 @@ class NextEditPredictionDataset(HuggingFaceDataset): } def sample(self, tokenizer: 
PreTrainedTokenizerBase, num_requests: int, + request_id_prefix: str = "", **kwargs): formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get( self.dataset_path) if formatting_prompt_func is None: raise ValueError(f"Unsupported dataset path: {self.dataset_path}") samples = [] - for sample in self.data: + for i, sample in enumerate(self.data): sample = formatting_prompt_func(sample) samples.append( SampleRequest( @@ -1363,10 +2112,11 @@ class NextEditPredictionDataset(HuggingFaceDataset): prompt_len=len(tokenizer(sample["prompt"]).input_ids), expected_output_len=len( tokenizer(sample["expected_output"]).input_ids), + request_id=request_id_prefix + str(i), )) if len(samples) >= num_requests: break - self.maybe_oversample_requests(samples, num_requests) + self.maybe_oversample_requests(samples, num_requests, request_id_prefix) return samples @@ -1416,6 +2166,7 @@ class ASRDataset(HuggingFaceDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, output_len: Optional[int] = None, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = (output_len @@ -1423,6 +2174,7 @@ class ASRDataset(HuggingFaceDataset): prompt = ASRDataset.TRANSCRIPTION_PREAMBLE prompt_len = len(tokenizer(prompt).input_ids) sampled_requests = [] + ind = 0 skipped = 0 for item in self.data: if len(sampled_requests) >= num_requests: @@ -1442,7 +2194,9 @@ class ASRDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), )) + ind += 1 if skipped: logger.warning( "%d samples discarded from dataset due to" @@ -1450,7 +2204,8 @@ class ASRDataset(HuggingFaceDataset): " what Whisper supports.", skipped, ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -1487,11 +2242,13 @@ class MLPerfDataset(HuggingFaceDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, 
output_len: Optional[int] = None, + request_id_prefix: str = "", **kwargs, ) -> list[SampleRequest]: # Force dynamic output length based on reference completion. dynamic_output = output_len is None sampled_requests: list[SampleRequest] = [] + ind = 0 for item in self.data: if len(sampled_requests) >= num_requests: @@ -1526,8 +2283,93 @@ class MLPerfDataset(HuggingFaceDataset): prompt=prompt_formatted, prompt_len=prompt_len, expected_output_len=expected_output_len, + request_id=request_id_prefix + str(ind), ) ) + ind += 1 - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests + + +# ----------------------------------------------------------------------------- +# Prefix Repetition Dataset Implementation +# ----------------------------------------------------------------------------- + + +class PrefixRepetitionRandomDataset(BenchmarkDataset): + # Default values copied from benchmark_serving.py for the repeated prefix + # dataset. 
+ DEFAULT_PREFIX_LEN = 256 + DEFAULT_SUFFIX_LEN = 256 + DEFAULT_NUM_PREFIXES = 10 + DEFAULT_OUTPUT_LEN = 128 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + random.seed(self.random_seed) + np.random.seed(self.random_seed) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + suffix_len: int = DEFAULT_SUFFIX_LEN, + num_prefixes: int = DEFAULT_NUM_PREFIXES, + output_len: int = DEFAULT_OUTPUT_LEN, + request_id_prefix: str = "", + **kwargs, + ) -> list[SampleRequest]: + vocab_size = tokenizer.vocab_size + prompts_per_prefix = num_requests // num_prefixes + if prompts_per_prefix == 0: + raise ValueError( + f"num_requests ({num_requests}) must be greater than or equal " + f"to num_prefixes ({num_prefixes})" + ) + + def _generate_exact_length_tokens(target_length: int) -> list[int]: + """Generate tokens that decode and re-encode to exactly + target_length.""" + # Generate random tokens + tokens = np.random.randint( + 0, vocab_size, size=target_length).tolist() + text = tokenizer.decode(tokens) + re_encoded = tokenizer.encode(text, add_special_tokens=False) + + if len(re_encoded) == target_length: + return re_encoded + elif len(re_encoded) < target_length: + # Recursively generate additional consistent tokens + needed = target_length - len(re_encoded) + extra_tokens = _generate_exact_length_tokens(needed) + return re_encoded + extra_tokens + else: + # Truncate to target length + return re_encoded[:target_length] + + requests = [] + for _ in range(num_prefixes): + prefix_tokens = _generate_exact_length_tokens(prefix_len) + + for _ in range(prompts_per_prefix): + suffix_tokens = _generate_exact_length_tokens(suffix_len) + + combined_tokens = prefix_tokens + suffix_tokens + prompt = tokenizer.decode(combined_tokens) + prompt_len = len(combined_tokens) + requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + ) + ) + + 
random.shuffle(requests) + return requests diff --git a/vllm/benchmarks/lib/endpoint_request_func.py b/vllm/benchmarks/lib/endpoint_request_func.py index 47bc288774504..6bb2a497119e9 100644 --- a/vllm/benchmarks/lib/endpoint_request_func.py +++ b/vllm/benchmarks/lib/endpoint_request_func.py @@ -9,7 +9,7 @@ import sys import time import traceback from dataclasses import dataclass, field -from typing import Optional +from typing import Optional, Union import aiohttp from tqdm.asyncio import tqdm @@ -28,9 +28,10 @@ class RequestFuncInput: model_name: Optional[str] = None logprobs: Optional[int] = None extra_body: Optional[dict] = None - multi_modal_content: Optional[dict | list[dict]] = None + multi_modal_content: Optional[Union[dict, list[dict]]] = None ignore_eos: bool = False language: Optional[str] = None + request_id: Optional[str] = None @dataclass @@ -68,8 +69,8 @@ async def async_request_openai_completions( ), "OpenAI Completions API URL must end with 'completions' or 'profile'." payload = { - "model": request_func_input.model_name \ - if request_func_input.model_name else request_func_input.model, + "model": request_func_input.model_name + if request_func_input.model_name else request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, "repetition_penalty": 1.0, @@ -87,6 +88,8 @@ async def async_request_openai_completions( headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -132,7 +135,7 @@ async def async_request_openai_completions( # Decoding phase else: output.itl.append(timestamp - - most_recent_timestamp) + most_recent_timestamp) most_recent_timestamp = timestamp generated_text += text or "" @@ -210,6 +213,8 @@ async def async_request_openai_chat_completions( "Content-Type": "application/json", "Authorization": f"Bearer 
{os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -249,7 +254,7 @@ async def async_request_openai_chat_completions( # Decoding phase else: output.itl.append(timestamp - - most_recent_timestamp) + most_recent_timestamp) generated_text += content or "" elif usage := data.get("usage"): @@ -311,6 +316,8 @@ async def async_request_openai_audio( headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id # Send audio file def to_bytes(y, sr): @@ -387,12 +394,61 @@ async def async_request_openai_audio( return output +async def async_request_openai_embeddings( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: Optional[tqdm] = None, +): + api_url = request_func_input.api_url + assert api_url.endswith( + "embeddings" + ), "OpenAI Embeddings API URL must end with 'embeddings'." + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + payload = { + "model": request_func_input.model, + "input": request_func_input.prompt, + } + + output = RequestFuncOutput() + st = time.perf_counter() + try: + async with session.post( + url=api_url, + headers=headers, + json=payload + ) as response: + if response.status == 200: + output.latency = time.perf_counter() - st + data = await response.json() + output.success = True + output.generated_text = "" + output.prompt_len = data.get( + "usage", {}).get( + "prompt_tokens", 0) + else: + output.success = False + output.error = response.reason or "" + except Exception as e: + output.success = False + output.error = str(e) + + if pbar: + pbar.update(1) + return output + + # TODO: Add more request functions for different API protocols. 
ASYNC_REQUEST_FUNCS = { "vllm": async_request_openai_completions, "openai": async_request_openai_completions, "openai-chat": async_request_openai_chat_completions, "openai-audio": async_request_openai_audio, + "openai-embeddings": async_request_openai_embeddings, } OPENAI_COMPATIBLE_BACKENDS = [ diff --git a/vllm/benchmarks/lib/utils.py b/vllm/benchmarks/lib/utils.py index 5f95fdcc75829..0c27687dcf16d 100644 --- a/vllm/benchmarks/lib/utils.py +++ b/vllm/benchmarks/lib/utils.py @@ -54,7 +54,12 @@ class InfEncoder(json.JSONEncoder): def clear_inf(self, o: Any): if isinstance(o, dict): - return {k: self.clear_inf(v) for k, v in o.items()} + return { + str(k) + if not isinstance(k, (str, int, float, bool, type(None))) + else k: self.clear_inf(v) + for k, v in o.items() + } elif isinstance(o, list): return [self.clear_inf(v) for v in o] elif isinstance(o, float) and math.isinf(o): diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 7bf04c7532411..abb838316cd31 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -4,7 +4,7 @@ r"""Benchmark online serving throughput. 
On the server side, run one of the following commands to launch the vLLM OpenAI API server: - vllm serve <your_model> <engine arguments> + vllm serve <your_model> <engine arguments> On the client side, run: vllm bench serve \ @@ -26,6 +26,7 @@ import warnings from collections.abc import AsyncGenerator, Iterable from dataclasses import dataclass from datetime import datetime +from enum import Enum from typing import Any, Literal, Optional import aiohttp @@ -46,6 +47,11 @@ from vllm.transformers_utils.tokenizer import get_tokenizer MILLISECONDS_TO_SECONDS_CONVERSION = 1000 +class TaskType(Enum): + GENERATION = "generation" + EMBEDDING = "embedding" + + @dataclass class BenchmarkMetrics: completed: int @@ -75,6 +81,16 @@ class BenchmarkMetrics: std_e2el_ms: float percentiles_e2el_ms: list[tuple[float, float]] +@dataclass +class EmbedBenchmarkMetrics: + completed: int + total_input: int + request_throughput: float + total_token_throughput :float + mean_e2el_ms: float + std_e2el_ms: float + median_e2el_ms: float + percentiles_e2el_ms: float def _get_current_request_rate( ramp_up_strategy: Optional[Literal["linear", "exponential"]], @@ -146,11 +162,11 @@ async def get_request( delay_ts = [] for request_index, request in enumerate(input_requests): current_request_rate = _get_current_request_rate(ramp_up_strategy, - ramp_up_start_rps, - ramp_up_end_rps, - request_index, - total_requests, - request_rate) + ramp_up_start_rps, + ramp_up_end_rps, + request_index, + total_requests, + request_rate) request_rates.append(current_request_rate) if current_request_rate == float("inf"): delay_ts.append(0) @@ -160,7 +176,7 @@ async def get_request( # Sample the request interval from the gamma distribution. # If burstiness is 1, it follows exponential distribution. delay_ts.append(np.random.gamma(shape=burstiness, scale=theta)) - + # Calculate the cumulative delay time from the first sent out requests. 
for i in range(1, len(delay_ts)): delay_ts[i] += delay_ts[i - 1] @@ -170,11 +186,11 @@ async def get_request( # logic would re-scale delay time to ensure the final delay_ts # align with target_total_delay_s. # - # NOTE: If we simply accumulate the random delta values - # from the gamma distribution, their sum would have 1-2% gap + # NOTE: If we simply accumulate the random delta values + # from the gamma distribution, their sum would have 1-2% gap # from target_total_delay_s. The purpose of the following logic is to - # close the gap for stablizing the throughput data - # from different random seeds. + # close the gap for stablizing the throughput data + # from different random seeds. target_total_delay_s = total_requests / request_rate normalize_factor = target_total_delay_s / delay_ts[-1] delay_ts = [delay * normalize_factor for delay in delay_ts] @@ -189,6 +205,51 @@ async def get_request( yield request, request_rates[request_index] +def calculate_metrics_for_embeddings( + outputs: list[RequestFuncOutput], + dur_s: float, + selected_percentiles: list[float] +) -> EmbedBenchmarkMetrics: + """Calculate the metrics for the embedding requests. + + Args: + outputs: The outputs of the requests. + dur_s: The duration of the benchmark. + selected_percentiles: The percentiles to select. + + Returns: + The calculated benchmark metrics. + """ + total_input = 0 + completed = 0 + e2els: list[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + e2els.append(outputs[i].latency) + completed += 1 + total_input += outputs[i].prompt_len + + if completed == 0: + warnings.warn( + "All requests failed. 
This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2) + metrics = EmbedBenchmarkMetrics( + completed=completed, + total_input=total_input, + request_throughput=completed / dur_s, + total_token_throughput=total_input / dur_s, + mean_e2el_ms=np.mean(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, + percentiles_e2el_ms=[ + (p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles + ], + ) + return metrics + + def calculate_metrics( input_requests: list[SampleRequest], outputs: list[RequestFuncOutput], @@ -334,8 +395,16 @@ async def benchmark( ramp_up_end_rps: Optional[int] = None, ready_check_timeout_sec: int = 600, ): + task_type = ( + TaskType.EMBEDDING + if api_url.endswith("/v1/embeddings") + else TaskType.GENERATION + ) if endpoint_type in ASYNC_REQUEST_FUNCS: - request_func = ASYNC_REQUEST_FUNCS[endpoint_type] + if task_type == TaskType.EMBEDDING: + request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"] + else: + request_func = ASYNC_REQUEST_FUNCS[endpoint_type] else: raise ValueError(f"Unknown endpoint_type: {endpoint_type}") @@ -421,8 +490,8 @@ async def benchmark( if profile_output.success: print("Profiler started") - distribution = ("Poisson process" if burstiness == 1.0 - else "Gamma distribution") + distribution = ("Poisson process" if burstiness == 1.0 + else "Gamma distribution") if ramp_up_strategy is not None: print(f"Traffic ramp-up strategy: {ramp_up_strategy}.") @@ -449,7 +518,7 @@ async def benchmark( session=session, pbar=pbar) async with semaphore: - return await request_func(request_func_input=request_func_input, + return await request_func(request_func_input=request_func_input, session=session, pbar=pbar) @@ -478,11 +547,12 @@ async def benchmark( "timestamp": timestamp }) last_int_rps = current_int_rps - prompt, prompt_len, output_len, mm_content = ( + prompt, prompt_len, output_len, mm_content, request_id = ( request.prompt, 
request.prompt_len, request.expected_output_len, request.multi_modal_data, + request.request_id, ) req_model_id, req_model_name = model_id, model_name if lora_modules: @@ -498,7 +568,8 @@ async def benchmark( logprobs=logprobs, multi_modal_content=mm_content, ignore_eos=ignore_eos, - extra_body=extra_body) + extra_body=extra_body, + request_id=request_id,) tasks.append( asyncio.create_task( limited_request_func(request_func_input=request_func_input, @@ -511,14 +582,22 @@ async def benchmark( benchmark_duration = time.perf_counter() - benchmark_start_time - metrics, actual_output_lens = calculate_metrics( - input_requests=input_requests, - outputs=outputs, - dur_s=benchmark_duration, - tokenizer=tokenizer, - selected_percentiles=selected_percentiles, - goodput_config_dict=goodput_config_dict, - ) + if task_type == TaskType.GENERATION: + metrics, actual_output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + selected_percentiles=selected_percentiles, + goodput_config_dict=goodput_config_dict, + ) + else: + metrics = calculate_metrics_for_embeddings( + outputs=outputs, + dur_s=benchmark_duration, + selected_percentiles=selected_percentiles, + ) + actual_output_lens = 0 print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) @@ -527,39 +606,55 @@ async def benchmark( max_concurrency)) if request_rate != float('inf'): print("{:<40} {:<10.2f}".format("Request rate configured (RPS):", - request_rate )) + request_rate)) print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) - print("{:<40} {:<10}".format("Total generated tokens:", - metrics.total_output)) + if isinstance(metrics, BenchmarkMetrics): + print("{:<40} {:<10}".format( + "Total generated tokens:", metrics.total_output)) print("{:<40} 
{:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput)) if goodput_config_dict: print("{:<40} {:<10.2f}".format("Request goodput (req/s):", metrics.request_goodput)) - print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", - metrics.output_throughput)) + if isinstance(metrics, BenchmarkMetrics): + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", metrics.total_token_throughput)) - result = { - "duration": benchmark_duration, - "completed": metrics.completed, - "total_input_tokens": metrics.total_input, - "total_output_tokens": metrics.total_output, - "request_throughput": metrics.request_throughput, - "request_goodput": - metrics.request_goodput if goodput_config_dict else None, - "output_throughput": metrics.output_throughput, - "total_token_throughput": metrics.total_token_throughput, - "input_lens": [output.prompt_len for output in outputs], - "output_lens": actual_output_lens, - "ttfts": [output.ttft for output in outputs], - "itls": [output.itl for output in outputs], - "generated_texts": [output.generated_text for output in outputs], - "errors": [output.error for output in outputs], - } + if isinstance(metrics, BenchmarkMetrics): + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, + "request_goodput": + metrics.request_goodput if goodput_config_dict else None, + "output_throughput": metrics.output_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": actual_output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": 
[output.error for output in outputs], + } + else: + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "request_throughput": metrics.request_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "errors": [output.error for output in outputs], + } if rps_change_events: result["rps_change_events"] = rps_change_events @@ -596,10 +691,11 @@ async def benchmark( value)) result[f"p{p_word}_{metric_attribute_name}_ms"] = value - process_one_metric("ttft", "TTFT", "Time to First Token") - process_one_metric("tpot", "TPOT", - "Time per Output Token (excl. 1st token)") - process_one_metric("itl", "ITL", "Inter-token Latency") + if task_type == TaskType.GENERATION: + process_one_metric("ttft", "TTFT", "Time to First Token") + process_one_metric( + "tpot", "TPOT", "Time per Output Token (excl. 1st token)") + process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency") print("=" * 50) @@ -730,7 +826,8 @@ def add_cli_args(parser: argparse.ArgumentParser): "initiated, this argument will control how many are actually allowed " "to execute at a time. 
This means that when used in combination, the " "actual request rate may be lower than specified with --request-rate, " - "if the server is not processing requests fast enough to keep up.") + "if the server is not processing requests fast enough to keep up.", + ) parser.add_argument( "--model", @@ -741,8 +838,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--tokenizer", type=str, - help= - "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) parser.add_argument("--use-beam-search", action="store_true") parser.add_argument( @@ -865,6 +961,14 @@ def add_cli_args(parser: argparse.ArgumentParser): "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "and the blog: https://hao-ai-lab.github.io/blogs/distserve", ) + parser.add_argument( + "--request-id-prefix", + type=str, + required=False, + default="benchmark-serving", + help="Specify the prefix of request id.", + ) + sampling_group = parser.add_argument_group("sampling parameters") sampling_group.add_argument( @@ -958,6 +1062,7 @@ def add_cli_args(parser: argparse.ArgumentParser): def main(args: argparse.Namespace) -> dict[str, Any]: return asyncio.run(main_async(args)) + async def main_async(args: argparse.Namespace) -> dict[str, Any]: print(args) random.seed(args.seed) @@ -1036,32 +1141,32 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: gc.freeze() benchmark_result = await benchmark( - endpoint_type=args.endpoint_type, - api_url=api_url, - base_url=base_url, - model_id=model_id, - model_name=model_name, - tokenizer=tokenizer, - input_requests=input_requests, - logprobs=args.logprobs, - request_rate=args.request_rate, - burstiness=args.burstiness, - disable_tqdm=args.disable_tqdm, - profile=args.profile, - selected_percentile_metrics=args.percentile_metrics.split(","), - selected_percentiles=[ - float(p) for p in 
args.metric_percentiles.split(",") - ], - ignore_eos=args.ignore_eos, - goodput_config_dict=goodput_config_dict, - max_concurrency=args.max_concurrency, - lora_modules=args.lora_modules, - extra_body=sampling_params, - ramp_up_strategy=args.ramp_up_strategy, - ramp_up_start_rps=args.ramp_up_start_rps, - ramp_up_end_rps=args.ramp_up_end_rps, - ready_check_timeout_sec=args.ready_check_timeout_sec, - ) + endpoint_type=args.endpoint_type, + api_url=api_url, + base_url=base_url, + model_id=model_id, + model_name=model_name, + tokenizer=tokenizer, + input_requests=input_requests, + logprobs=args.logprobs, + request_rate=args.request_rate, + burstiness=args.burstiness, + disable_tqdm=args.disable_tqdm, + profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[ + float(p) for p in args.metric_percentiles.split(",") + ], + ignore_eos=args.ignore_eos, + goodput_config_dict=goodput_config_dict, + max_concurrency=args.max_concurrency, + lora_modules=args.lora_modules, + extra_body=sampling_params, + ramp_up_strategy=args.ramp_up_strategy, + ramp_up_start_rps=args.ramp_up_start_rps, + ramp_up_end_rps=args.ramp_up_end_rps, + ready_check_timeout_sec=args.ready_check_timeout_sec, + ) # Save config and results to json result_json: dict[str, Any] = {} @@ -1088,7 +1193,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: # Traffic result_json["request_rate"] = (args.request_rate if args.request_rate - < float("inf") else "inf") + < float("inf") else "inf") result_json["burstiness"] = args.burstiness result_json["max_concurrency"] = args.max_concurrency @@ -1122,7 +1227,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: if args.max_concurrency is not None else "") label = label or endpoint_type if args.ramp_up_strategy is not None: - file_name = 
f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa + file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa else: file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa if args.result_filename: @@ -1139,4 +1244,4 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: json.dump(result_json, outfile) save_to_pytorch_benchmark_format(args, result_json, file_name) - return result_json \ No newline at end of file + return result_json diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index fdf6548ada5b6..f022a55e625f5 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -18,9 +18,11 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, from vllm.benchmarks.datasets import (AIMODataset, BurstGPTDataset, ConversationDataset, - InstructCoderDataset, RandomDataset, - SampleRequest, ShareGPTDataset, - SonnetDataset, VisionArenaDataset) + InstructCoderDataset, + PrefixRepetitionRandomDataset, + RandomDataset, SampleRequest, + ShareGPTDataset, SonnetDataset, + VisionArenaDataset) from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format, write_to_json) from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs @@ -327,6 +329,12 @@ def get_requests(args, tokenizer): dataset_cls = AIMODataset common_kwargs['dataset_subset'] = None common_kwargs['dataset_split'] = "train" + elif args.dataset_name == "prefix_repetition": + dataset_cls = PrefixRepetitionRandomDataset + sample_kwargs["prefix_len"] = args.prefix_repetition_prefix_len + sample_kwargs["suffix_len"] = args.prefix_repetition_suffix_len + sample_kwargs["num_prefixes"] = args.prefix_repetition_num_prefixes + sample_kwargs["output_len"] = 
args.prefix_repetition_output_len else: raise ValueError(f"Unknown dataset name: {args.dataset_name}") # Remove None values @@ -356,7 +364,11 @@ def validate_args(args): raise ValueError(f"Unsupported backend: {args.backend}") # === Dataset Configuration === - if not args.dataset and not args.dataset_path: + if ( + not args.dataset + and not args.dataset_path + and args.dataset_name not in {"prefix_repetition"} + ): print( "When dataset path is not set, it will default to random dataset") args.dataset_name = 'random' @@ -422,6 +434,14 @@ def validate_args(args): if args.backend == "mii" and args.tokenizer != args.model: raise ValueError( "Tokenizer must be the same as the model for MII backend.") + + # --data-parallel is not supported currently. + # https://github.com/vllm-project/vllm/issues/16222 + if args.data_parallel_size > 1: + raise ValueError( + "Data parallel is not supported in offline benchmark, " + "please use benchmark serving instead" + ) def add_cli_args(parser: argparse.ArgumentParser): @@ -432,7 +452,10 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--dataset-name", type=str, - choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"], + choices=[ + "sharegpt", "random", "sonnet", "burstgpt", "hf", + "prefix_repetition" + ], help="Name of the dataset to benchmark on.", default="sharegpt") parser.add_argument( @@ -521,6 +544,38 @@ def add_cli_args(parser: argparse.ArgumentParser): default=None, help="Split of the HF dataset.") + # prefix repetition dataset + prefix_repetition_group = parser.add_argument_group( + "prefix repetition dataset options") + prefix_repetition_group.add_argument( + "--prefix-repetition-prefix-len", + type=int, + default=None, + help="Number of prefix tokens per request, used only for prefix " + "repetition dataset.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-suffix-len", + type=int, + default=None, + help="Number of suffix tokens per request, used only for prefix " + 
"repetition dataset. Total input length is prefix_len + suffix_len.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-num-prefixes", + type=int, + default=None, + help="Number of prefixes to generate, used only for prefix repetition " + "dataset. Prompts per prefix is num_requests // num_prefixes.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-output-len", + type=int, + default=None, + help="Number of output tokens per request, used only for prefix " + "repetition dataset.", + ) + parser = AsyncEngineArgs.add_cli_args(parser) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 059e7a3b29761..fa86773d24743 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -294,13 +294,12 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): def __init__(self, module: torch.fx.GraphModule, compile_submod_names: list[str], vllm_config: VllmConfig, - graph_pool, vllm_backend: "VllmBackend"): + vllm_backend: "VllmBackend"): super().__init__(module) from torch._guards import detect_fake_mode self.fake_mode = detect_fake_mode() self.compile_submod_names = compile_submod_names self.compilation_config = vllm_config.compilation_config - self.graph_pool = graph_pool self.vllm_config = vllm_config self.vllm_backend = vllm_backend # When True, it annoyingly dumps the torch.fx.Graph on errors. @@ -359,7 +358,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): runnable=piecewise_backend, vllm_config=self.vllm_config, runtime_mode=CUDAGraphMode.PIECEWISE, - graph_pool=self.graph_pool, cudagraph_options=CUDAGraphOptions( debug_log_enable=piecewise_backend.is_first_graph, gc_disable=not piecewise_backend.is_first_graph, @@ -405,7 +403,6 @@ class VllmBackend: vllm_config: VllmConfig compilation_config: CompilationConfig - graph_pool: Any _called: bool = False # the graph we compiled graph: fx.GraphModule @@ -433,13 +430,6 @@ class VllmBackend: # them, e.g. 
backbone (default), eagle_head, etc. self.prefix = prefix or model_tag - global_graph_pool = current_platform.get_global_graph_pool() - - # TODO: in the future, if we want to use multiple - # streams, it might not be safe to share a global pool. - # only investigate this when we use multiple streams - self.graph_pool = global_graph_pool - # Passes to run on the graph post-grad. self.post_grad_pass_manager = PostGradPassManager() @@ -484,7 +474,7 @@ class VllmBackend: factors = [] # 0. factors come from the env, for example, The values of - # VLLM_PP_LAYER_PARTITION will affects the computation graph. + # VLLM_PP_LAYER_PARTITION will affect the computation graph. env_hash = envs.compute_hash() factors.append(env_hash) @@ -586,7 +576,7 @@ class VllmBackend: # propagate the split graph to the piecewise backend, # compile submodules with symbolic shapes PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile, - self.vllm_config, self.graph_pool, + self.vllm_config, self).run(*example_inputs) graph_path = os.path.join(local_cache_dir, "computation_graph.py") diff --git a/vllm/compilation/base_static_graph.py b/vllm/compilation/base_static_graph.py index 1c3f52c533b13..161d066ce9fb8 100644 --- a/vllm/compilation/base_static_graph.py +++ b/vllm/compilation/base_static_graph.py @@ -13,7 +13,7 @@ class AbstractStaticGraphWrapper(Protocol): """ def __init__(self, runnable: Callable, vllm_config: VllmConfig, - runtime_mode: CUDAGraphMode, graph_pool: Any, **kwargs): + runtime_mode: CUDAGraphMode, **kwargs): """ Initializes the StaticGraphWrapper class with graph capturing and execution-related configurations. @@ -25,9 +25,6 @@ class AbstractStaticGraphWrapper(Protocol): graph runtime. See CUDAGraphMode in vllm/config.py. Note that only the subset enum `NONE`, `PIECEWISE` and `FULL` are used as concrete runtime mode for cudagraph dispatching. - graph_pool (Any): - Graph memory pool handle, e.g., - `torch.cuda.graph_pool_handle()`. 
Keyword Args: kwargs: Additional keyword arguments for platform-specific configurations. diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 6ae50245ed3a8..0c545d8cffd24 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -10,6 +10,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import enable_symm_mem_for_group +import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( @@ -401,6 +402,18 @@ if flashinfer_comm is not None: 6: MiB // 2, # 512KB 8: MiB // 2, # 512KB } + + try: + _FI_MAX_SIZES.update({ + int(k): int(float(v) * MiB) + for k, v in + envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items() + }) + except Exception as e: + raise ValueError( + "Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " + + str(e)) from e + # opt for a more conservative default value # when world size is not in _FI_MAX_SIZES _DEFAULT_FI_MAX_SIZE = MiB // 2 @@ -465,7 +478,8 @@ if flashinfer_comm is not None: quant_out=quant_out, scale_out=scale_out, # in vllm we only support swizzled layout - layout_code=flashinfer_comm.FP4QuantizationSFLayout.SWIZZLED, + layout_code=flashinfer_comm.QuantizationSFLayout. 
+ SWIZZLED_128x4, scale_factor=scale_factor, ) else: diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 65a38197ad4e2..e233f959c0a4a 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -67,11 +67,9 @@ class CUDAGraphWrapper: runnable: Callable, vllm_config: VllmConfig, runtime_mode: CUDAGraphMode, - graph_pool: Any = None, cudagraph_options: Optional[CUDAGraphOptions] = None): self.runnable = runnable self.vllm_config = vllm_config - self.graph_pool = graph_pool self.runtime_mode = runtime_mode self.compilation_config = vllm_config.compilation_config @@ -81,8 +79,10 @@ class CUDAGraphWrapper: # assert runtime_mode is not NONE(no cudagraph), otherwise, we don't # need to initialize a CUDAGraphWrapper. assert self.runtime_mode != CUDAGraphMode.NONE - if self.graph_pool is None: - self.graph_pool = current_platform.get_global_graph_pool() + # TODO: in the future, if we want to use multiple + # streams, it might not be safe to share a global pool. + # only investigate this when we use multiple streams + self.graph_pool = current_platform.get_global_graph_pool() if cudagraph_options is None: cudagraph_options = CUDAGraphOptions() diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 1370862d580a5..41d9fcb824b01 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -52,6 +52,14 @@ def _should_ignore_torch_compile(cls) -> bool: return getattr(cls, IGNORE_COMPILE_KEY, False) +@overload +def support_torch_compile( + *, + enable_if: Optional[Callable[[VllmConfig], bool]] = None, +) -> Callable[[_T], _T]: + ... 
+ + @overload def support_torch_compile( *, @@ -69,6 +77,7 @@ def support_torch_compile( cls: Optional[_T] = None, *, dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None, + enable_if: Optional[Callable[[VllmConfig], bool]] = None, ) -> Union[Callable[[_T], _T], _T]: """ A decorator to add support for compiling the forward method of a class. @@ -118,6 +127,11 @@ def support_torch_compile( NOTE: if an argument is `None`, it should always be passed as `None` during the lifetime of the model, otherwise, it cannot be captured as a single computation graph. + + `enable_if` is a function that takes a `VllmConfig` object as input and + returns a boolean value indicating whether to compile the model or not. + This is useful if you want to compile the model only when certain + conditions are met. """ def cls_decorator_helper(cls: _T) -> _T: @@ -149,7 +163,8 @@ def support_torch_compile( if k not in sig.parameters: raise ValueError( f"Argument {k} not found in the forward method of {cls}") - return _support_torch_compile(cls, inferred_dynamic_arg_dims) + return _support_torch_compile(cls, inferred_dynamic_arg_dims, + enable_if) if cls is not None: # use `support_torch_compile` as a decorator without arguments @@ -162,6 +177,7 @@ def support_torch_compile( def _support_torch_compile( cls: _T, dynamic_arg_dims: dict[str, Union[int, list[int]]], + enable_if: Optional[Callable[[VllmConfig], bool]] = None, ) -> _T: """ A decorator to add support for compiling the forward method of a class. @@ -182,13 +198,14 @@ def _support_torch_compile( def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) self.vllm_config = vllm_config + enable_compile = enable_if is None or enable_if(vllm_config) # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner # will handle the compilation, so we don't need to do anything here. 
self.do_not_compile = \ vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS ] or not supports_dynamo() or _should_ignore_torch_compile( - self.__class__) + self.__class__) or not enable_compile if self.do_not_compile: return @@ -267,8 +284,24 @@ def _support_torch_compile( code.co_filename) return inline_call(parent, func, args, kwargs) + # Disable the C++ compilation of symbolic shape guards. C++-fication + # of symbolic shape guards can reduce guard overhead. But since + # vLLM skips guards anyway, setting this flag to False can improve + # compile time. + dynamo_config_patches = {} + try: + _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards + dynamo_config_patches[ + "enable_cpp_symbolic_shape_guards"] = False + except AttributeError: + # Note: this config is not available in torch 2.6, we can skip + # if the config doesn't exist + logger.debug( + "enable_cpp_symbolic_shape_guards config not available") + with patch.object(InliningInstructionTranslator, 'inline_call', - patched_inline_call): + patched_inline_call), torch._dynamo.config.patch( + **dynamo_config_patches): output = self.compiled_callable(*args, **kwargs) return output diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 286221d32c1ee..60ae143318790 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -9,6 +9,7 @@ import torch from torch._higher_order_ops.auto_functionalize import auto_functionalized from vllm.logger import init_logger +from vllm.platforms import current_platform from .fx_utils import is_func from .vllm_inductor_pass import VllmInductorPass @@ -26,6 +27,13 @@ class FixFunctionalizationPass(VllmInductorPass): """ def __call__(self, graph: torch.fx.Graph): + # XPU does not support auto-functionalization yet. + # Will enable this when we switch to vllm-xpu-kernels.
+ if current_platform.is_xpu(): + logger.debug("XPU platform does not support fix functionalization" + "pass currently.") + return + self.begin() self.dump_graph(graph, "before_fix_functionalization") diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 3dec939c28351..0d8d562514e31 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -12,7 +12,8 @@ from torch._ops import OpOverload from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym, + kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) from vllm.platforms import current_platform from .fx_utils import find_getitem_maybe @@ -21,6 +22,7 @@ from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 def empty_bf16(*args, **kwargs): @@ -31,42 +33,13 @@ def empty_fp32(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") +def empty_i32(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda") + + RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default - -class QuantKey(NamedTuple): - """ - Named tuple for identifying the type of quantization. 
- dtype: quantized data type - static: static quantization if True, dynamic if False - group_shape: quantization group shape - symmetric: symmetric if True, asymmetric if False - - TODO(luka) use QuantDescriptor once standardized: - https://github.com/vllm-project/vllm/issues/8913 - - """ - dtype: torch.dtype - static: bool - group_shape: GroupShape - symmetric: bool = True - - def __str__(self): - group_shape = ('per_tensor' - if self.group_shape == GroupShape.PER_TENSOR else - ('per_token' if self.group_shape == GroupShape.PER_TOKEN - else str(self.group_shape))) - - return (f"QuantKey({'static' if self.static else 'dynamic'}," - f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape}," - f"{'a' if not self.symmetric else ''}symmetric)") - - -kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True) -kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True) -kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True) - QUANT_OPS: dict[QuantKey, OpOverload] = { kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 @@ -75,6 +48,9 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } +if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): + QUANT_OPS[ + kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 class FusedRMSQuantKey(NamedTuple): @@ -187,11 +163,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, symmetric=True): fused_key = FusedRMSQuantKey(fused_add=False, - quant=QuantKey( - dtype=quant_dtype, - static=True, - group_shape=GroupShape.PER_TENSOR, - symmetric=symmetric)) + quant=QuantKey(dtype=quant_dtype, + scale=kStaticTensorScale, + symmetric=symmetric)) super().__init__(epsilon, fused_key) def register(self, pm_pass: PatternMatcherPass): @@ -244,11 +218,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): 
quant_dtype: torch.dtype, symmetric=True): key = FusedRMSQuantKey(fused_add=True, - quant=QuantKey( - dtype=quant_dtype, - static=True, - group_shape=GroupShape.PER_TENSOR, - symmetric=symmetric)) + quant=QuantKey(dtype=quant_dtype, + scale=kStaticTensorScale, + symmetric=symmetric)) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass, @@ -337,10 +309,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, group_shape: GroupShape = GroupShape.PER_TOKEN, symmetric=True): + scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey(fused_add=False, quant=QuantKey(dtype=quant_dtype, - static=False, - group_shape=group_shape, + scale=scale, symmetric=symmetric)) super().__init__(epsilon, key) @@ -435,10 +407,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, group_shape: GroupShape = GroupShape.PER_TOKEN, symmetric=True): + scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey(fused_add=True, quant=QuantKey(dtype=quant_dtype, - static=False, - group_shape=group_shape, + scale=scale, symmetric=symmetric)) super().__init__(epsilon, key) diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index a40a8caf34a88..f942afe6a28ee 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod + import torch import torch._inductor.pattern_matcher as pm from torch._higher_order_ops.auto_functionalize import auto_functionalized @@ -9,37 +11,43 @@ from torch._subclasses.fake_tensor import (FakeTensorMode, unset_fake_temporarily) from vllm.attention import Attention -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger +from 
vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kNvfp4Quant, kStaticTensorScale) from vllm.platforms import current_platform +from vllm.utils import round_up -from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32 +from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 + ATTN_OP = torch.ops.vllm.unified_attention_with_output.default RESHAPE_OP = torch.ops.aten.reshape.default -class AttentionStaticQuantPattern: +class AttentionQuantPattern(ABC): + """ + The base class for Attn+Quant fusions. + Should not be used directly. + """ def __init__( self, - layer_name: str, - num_heads: int, - head_size: int, - quant_dtype: torch.dtype, - symmetric=True, + layer: Attention, + quant_key: QuantKey, ): - self.layer_name = layer_name - self.num_heads = num_heads - self.head_size = head_size - self.quant_dtype = quant_dtype - self.quant_key = QuantKey(dtype=quant_dtype, - static=True, - group_shape=GroupShape.PER_TENSOR, - symmetric=symmetric) + self.layer = layer + self.layer_name = layer.layer_name + self.num_heads = layer.num_heads + self.head_size = layer.head_size + self.quant_key = quant_key + self.quant_dtype = quant_key.dtype + assert self.quant_key in QUANT_OPS, \ f"unsupported quantization scheme {self.quant_key}" self.QUANT_OP = QUANT_OPS[self.quant_key] @@ -48,31 +56,64 @@ class AttentionStaticQuantPattern: kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} return torch.empty(*args, **kwargs) - def register_if_supported(self, pm_pass: PatternMatcherPass, - layer: Attention): - if layer.impl.fused_output_quant_supported(self.quant_dtype, - self.quant_key.static, - self.quant_key.group_shape): + @staticmethod + def wrap_trace_fn(process_fx, trace_fn): + + def wrapped(*args, **kwargs): + return process_fx(trace_fn(*args, **kwargs)) + + return wrapped 
+ + @staticmethod + def fx_view_to_reshape(gm: torch.fx.GraphModule): + from torch._inductor.fx_passes.post_grad import view_to_reshape + view_to_reshape(gm) + return gm + + def register_if_supported(self, pm_pass: PatternMatcherPass): + if self.layer.impl.fused_output_quant_supported(self.quant_key): self._register(pm_pass) + @abstractmethod + def _register(self, pm_pass: PatternMatcherPass): + raise NotImplementedError + + +class AttentionFp8StaticQuantPattern(AttentionQuantPattern): + """ + Fusion for Attention+Fp8StaticQuant. + + Only triggers when the attention implementation returns True in + `fused_output_quant_supported()`. If the pattern is found, the + Fp8StaticQuant op will be removed from the graph, and its scale + will be passed into Attention op as the `output_scale` argument. + """ + + def __init__( + self, + layer: Attention, + symmetric: bool = True, + ): + quant_key = QuantKey(dtype=FP8_DTYPE, + scale=kStaticTensorScale, + symmetric=symmetric) + super().__init__(layer, quant_key) + def _register(self, pm_pass: PatternMatcherPass): def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, output_quant: torch.Tensor, scale: torch.Tensor): - view_7 = RESHAPE_OP(output_attn, - [-1, self.num_heads, self.head_size]) - at1 = auto_functionalized(ATTN_OP, query=q, key=k, value=v, - output=view_7, + output=output_attn, layer_name=self.layer_name, - output_scale=None) - attn_out_view = RESHAPE_OP(at1[1], - [-1, self.num_heads * self.head_size]) - + output_scale=None, + output_block_scale=None) + attn_out_view = RESHAPE_OP( + at1[1], [q.shape[0], self.num_heads * self.head_size]) at2 = auto_functionalized(self.QUANT_OP, result=output_quant, input=attn_out_view, @@ -82,17 +123,20 @@ class AttentionStaticQuantPattern: def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, output_quant: torch.Tensor, scale: torch.Tensor): - view_7 = RESHAPE_OP(output_quant, - [-1, self.num_heads, 
self.head_size]) - + # attn output in quant_dtype + output_attn = torch.ops.aten.full.default( + [q.shape[0], self.num_heads, self.head_size], + 0.0, + dtype=self.quant_dtype, + device=q.device) at1 = auto_functionalized(ATTN_OP, query=q, key=k, value=v, - output=view_7, + output=output_attn, layer_name=self.layer_name, - output_scale=scale) - + output_scale=scale, + output_block_scale=None) return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) # Need custom fake mode, otherwise tracing happens with real tensors. @@ -102,27 +146,100 @@ class AttentionStaticQuantPattern: empty_bf16(5, self.num_heads, self.head_size), # q empty_bf16(5, self.num_heads, self.head_size), # k empty_bf16(5, self.num_heads, self.head_size), # v - empty_bf16(5, self.num_heads * self.head_size), # attn_output + empty_bf16(5, self.num_heads, self.head_size), # attn_output self.empty_quant(5, self.num_heads * self.head_size), # quant_output empty_fp32(1, 1) # scale ] - def wrap_trace_fn(process_fx, trace_fn): + pm.register_replacement( + pattern, replacement, inputs, + AttentionQuantPattern.wrap_trace_fn( + AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only), + pm_pass) - def wrapped(*args, **kwargs): - return process_fx(trace_fn(*args, **kwargs)) - return wrapped +class AttentionNvfp4QuantPattern(AttentionQuantPattern): + """ + Fusion for Attention+Nvfp4Quant. - def fx_view_to_reshape(gm: torch.fx.GraphModule): - from torch._inductor.fx_passes.post_grad import view_to_reshape - view_to_reshape(gm) - return gm + Only triggers when the attention implementation returns True in + `fused_output_quant_supported()`. If the pattern is found, the + Nvfp4Quant op will be removed from the graph, and its scale + will be passed into Attention op as the `output_scale` argument. 
+ """ + + def __init__(self, layer: Attention): + super().__init__(layer, kNvfp4Quant) + + def _register(self, pm_pass: PatternMatcherPass): + + def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + output_attn: torch.Tensor, output_quant: torch.Tensor, + output_scale: torch.Tensor, input_scale: torch.Tensor): + at1 = auto_functionalized(ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + layer_name=self.layer_name, + output_scale=None, + output_block_scale=None) + attn_out_view = RESHAPE_OP( + at1[1], [q.shape[0], self.num_heads * self.head_size]) + at2 = auto_functionalized(self.QUANT_OP, + output=output_quant, + input=attn_out_view, + output_scale=output_scale, + input_scale=input_scale) + output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE) + return at2[1], output_scale_view + + def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + output_attn: torch.Tensor, output_quant: torch.Tensor, + output_scale: torch.Tensor, input_scale: torch.Tensor): + # attention output in quant_dtype + output_attn = torch.ops.aten.full.default( + [q.shape[0], self.num_heads, self.head_size // 2], + 0.0, + dtype=self.quant_dtype, + device=q.device) + # attention output block scale + output_scale_view = torch.ops.aten.view.dtype( + output_scale, FP8_DTYPE) + at2 = auto_functionalized(ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + layer_name=self.layer_name, + output_scale=input_scale, + output_block_scale=output_scale_view) + output = RESHAPE_OP(at2[1], + [-1, self.num_heads * self.head_size // 2]) + return output, at2[2] + + # Need custom fake mode, otherwise tracing happens with real tensors. + # That would not work for the unified_attention custom op. 
+ with unset_fake_temporarily(), FakeTensorMode(): + inputs = [ + empty_bf16(5, self.num_heads, self.head_size), # q + empty_bf16(5, self.num_heads, self.head_size), # k + empty_bf16(5, self.num_heads, self.head_size), # v + empty_bf16(5, self.num_heads, self.head_size), # output_attn + self.empty_quant(5, self.num_heads * self.head_size // + 2), # output_quant + empty_i32(128, + round_up(self.num_heads * self.head_size // 16, + 4)), # output_scale + empty_fp32(1, 1), # input_scale + ] pm.register_replacement( pattern, replacement, inputs, - wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass) + AttentionQuantPattern.wrap_trace_fn( + AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only), + pm_pass) class AttnFusionPass(VllmInductorPass): @@ -140,30 +257,39 @@ class AttnFusionPass(VllmInductorPass): def __init__(self, config: VllmConfig): super().__init__(config) - self.static_fwd_ctx = config.compilation_config.static_forward_context self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass") - for key, layer in self.static_fwd_ctx.items(): - pattern = AttentionStaticQuantPattern(key, layer.num_heads, - layer.head_size, - current_platform.fp8_dtype()) - pattern.register_if_supported(self.patterns, layer) - if len(self.static_fwd_ctx) == 0: + attn_layers = get_layers_from_vllm_config(config, Attention) + for layer_name, layer in attn_layers.items(): + pattern_fp8 = AttentionFp8StaticQuantPattern(layer) + pattern_fp8.register_if_supported(self.patterns) + + pattern_nvfp4 = AttentionNvfp4QuantPattern(layer) + pattern_nvfp4.register_if_supported(self.patterns) + + if len(attn_layers) == 0: logger.warning( - "Attention + quant fusion is enabled, but " - "CompilationConfig.static_forward_context is empty. 
" - "Cannot access attention layers so no fusion " - "patterns were registered.") + "Attention + quant fusion is enabled, but no attention layers " + "were found in CompilationConfig.static_forward_context " + "so no fusion patterns were registered.") def __call__(self, graph: torch.fx.graph.Graph) -> None: self.begin() self.dump_graph(graph, "before_attn_fusion") count = self.patterns.apply(graph) + + # TODO: Move this to pass_manager.py after the fx graph broken issue + # has been resolved. + # see https://github.com/vllm-project/vllm/issues/23091 + graph.eliminate_dead_code() + logger.debug("Fused quantization onto %s attention nodes", count) self.dump_graph(graph, "after_attn_fusion") self.end_and_log() def uuid(self): - return VllmInductorPass.hash_source(self, AttentionStaticQuantPattern) + return VllmInductorPass.hash_source(self, AttentionQuantPattern, + AttentionFp8StaticQuantPattern, + AttentionNvfp4QuantPattern) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 280ae60c91ff4..e3fb6d796def5 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -33,7 +33,8 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType, PrefixCachingHashAlgo) from vllm.config.compilation import (CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig) -from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig +from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig, + ParallelConfig) from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy from vllm.config.utils import ConfigType, config from vllm.logger import init_logger @@ -62,6 +63,7 @@ if TYPE_CHECKING: QuantizationConfig) from vllm.model_executor.model_loader import LoadFormats from vllm.model_executor.model_loader.tensorizer import TensorizerConfig + from vllm.v1.sample.logits_processor import LogitsProcessor HfOverrides = Union[dict, Callable[[type], type]] else: @@ -72,6 +74,7 @@ else: BaseModelLoader = Any 
LoadFormats = Any TensorizerConfig = Any + LogitsProcessor = Any HfOverrides = Union[dict[str, Any], Callable[[type], type]] me_quant = LazyLoader("model_executor", globals(), @@ -189,7 +192,17 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]: yield a, b a = b - cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0] + try: + cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0] + except (OSError, KeyError, TypeError): + # HACK: Python 3.13+ workaround - set missing __firstlineno__ + # Workaround can be removed after we upgrade to pydantic==2.12.0 + with open(inspect.getfile(cls)) as f: + for i, line in enumerate(f): + if f"class {cls.__name__}" in line and ":" in line: + cls.__firstlineno__ = i + 1 + break + cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0] if not isinstance(cls_node, ast.ClassDef): raise TypeError("Given object was not a class.") @@ -244,8 +257,14 @@ def is_init_field(cls: ConfigType, name: str) -> bool: TokenizerMode = Literal["auto", "slow", "mistral", "custom"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] -LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs", - "processed_logits"] +MMEncoderTPMode = Literal["weights", "data"] + + +class LogprobsMode(enum.Enum): + RAW_LOGITS = "raw_logits" + RAW_LOGPROBS = "raw_logprobs" + PROCESSED_LOGITS = "processed_logits" + PROCESSED_LOGPROBS = "processed_logprobs" @config @@ -349,12 +368,13 @@ class ModelConfig: specified in `SamplingParams`. The default value comes the default for the OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length * vocab_size) logprobs are allowed to be returned and it may cause OOM.""" - logprobs_mode: LogprobsMode = "raw_logprobs" + logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS """Indicates the content returned in the logprobs and prompt_logprobs. Supported mode: 1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits. 
- Raw means the values before applying logit processors, like bad words. - Processed means the values after applying such processors. + Raw means the values before applying any logit processors, like bad words. + Processed means the values after applying all processors, including + temperature and top_k/top_p. """ disable_sliding_window: bool = False """Whether to disable sliding window. If True, we will disable the sliding @@ -417,7 +437,7 @@ class ModelConfig: from `AutoProcessor.from_pretrained`. The available overrides depend on the model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. """ - mm_processor_cache_gb: int = 4 + mm_processor_cache_gb: float = 4 """The size (in GiB) of the multi-modal processor cache, which is used to avoid re-processing past multi-modal inputs. @@ -426,6 +446,19 @@ class ModelConfig: `mm_processor_cache_gb * (api_server_count + data_parallel_size)`. Set to `0` to disable this cache completely (not recommended).""" + mm_encoder_tp_mode: MMEncoderTPMode = "weights" + """Indicates how to optimize multi-modal encoder inference using + tensor parallelism (TP). + + - `"weights"`: Within the same vLLM engine, split the weights of + each layer across TP ranks. (default TP behavior) + - `"data"`: Within the same vLLM engine, split the batched input data + across TP ranks to process the data in parallel, while hosting + the full weights on each TP rank. + This batch-level DP is not to be confused with API request-level + DP (which is controlled by `--data-parallel-size`). 
+ This is only supported on a per-model basis and falls back to + `"weights"` if the encoder does not support DP.""" override_neuron_config: dict[str, Any] = field(default_factory=dict) """Initialize non-default neuron config or override default neuron config that are specific to Neuron devices, this argument will be used to @@ -465,6 +498,9 @@ class ModelConfig: - "transformers" will use the Transformers model implementation.""" override_attention_dtype: Optional[str] = None """Override dtype for attention""" + logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None + """One or more logits processors' fully-qualified class names or class + definitions""" def compute_hash(self) -> str: """ @@ -836,22 +872,25 @@ class ModelConfig: def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: if self._model_info.supports_multimodal: + if (self.mm_encoder_tp_mode == "data" and + not self._model_info.supports_multimodal_encoder_tp_data): + logger.warning_once( + "This model does not support `--mm-encoder-tp-mode data`. 
" + "Falling back to `--mm-encoder-tp-mode weights`.") + self.mm_encoder_tp_mode = "weights" + return MultiModalConfig( limit_per_prompt=self.limit_mm_per_prompt, media_io_kwargs=self.media_io_kwargs, mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, + mm_encoder_tp_mode=self.mm_encoder_tp_mode, interleave_mm_strings=self.interleave_mm_strings, - skip_mm_profiling=self.skip_mm_profiling) + skip_mm_profiling=self.skip_mm_profiling, + ) return None - def set_mm_processor_cache_gb(self, value: int) -> None: - mm_config = self.get_multimodal_config() - - self.mm_processor_cache_gb = value - mm_config.mm_processor_cache_gb = value - def _get_encoder_config(self): return get_sentence_transformer_tokenizer_config( self.model, self.revision) @@ -1081,9 +1120,20 @@ class ModelConfig: def _verify_quantization(self) -> None: supported_quantization = me_quant.QUANTIZATION_METHODS optimized_quantization_methods = [ - "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", - "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8", - "quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "inc" + "fp8", + "modelopt", + "gptq_marlin_24", + "gptq_marlin", + "awq_marlin", + "fbgemm_fp8", + "compressed-tensors", + "experts_int8", + "quark", + "modelopt_fp4", + "bitblas", + "gptq_bitblas", + "inc", + "petit_nvfp4", ] if self.quantization is not None: self.quantization = cast(me_quant.QuantizationMethods, @@ -1106,7 +1156,6 @@ class ModelConfig: # `override_quantization_method` method) must be checked in order # of preference (this is particularly important for GPTQ). 
overrides = [ - "marlin", "bitblas", "gptq_marlin_24", "gptq_marlin", @@ -1116,6 +1165,7 @@ class ModelConfig: "moe_wna16", "modelopt", "modelopt_fp4", + "petit_nvfp4", ] quantization_methods = [ q for q in supported_quantization if q not in overrides @@ -1448,7 +1498,8 @@ class ModelConfig: from vllm.distributed.utils import get_pp_indices if (self.hf_text_config.model_type == "deepseek_mtp" or self.hf_config.model_type == "mimo_mtp" - or self.hf_config.model_type == "glm4_moe_mtp"): + or self.hf_config.model_type == "glm4_moe_mtp" + or self.hf_config.model_type == "ernie_mtp"): total_num_hidden_layers = getattr(self.hf_text_config, "num_nextn_predict_layers", 0) else: @@ -1647,31 +1698,6 @@ class ModelConfig: def is_multimodal_model(self) -> bool: return self.multimodal_config is not None - @property - def processor_return_mm_hashes(self) -> bool: - """Whether the multi-modal processor should output hashes.""" - mm_config = self.multimodal_config - if mm_config is None: - return False - - return mm_config.mm_processor_cache_gb > 0 - - @property - def enable_mm_processor_cache(self) -> bool: - """Whether the multi-modal processor cache should be enabled.""" - mm_config = self.multimodal_config - if mm_config is None: - return False - - return mm_config.mm_processor_cache_gb > 0 - - def get_mm_input_cache_gb(self) -> int: - mm_config = self.multimodal_config - if mm_config is None: - return 0 - - return envs.VLLM_MM_INPUT_CACHE_GIB - @property def is_cross_encoder(self) -> bool: return (self._model_info.supports_cross_encoding @@ -1896,7 +1922,8 @@ class DeviceConfig: SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", - "mlp_speculator", "draft_model", "deepseek_mtp"] + "mlp_speculator", "draft_model", "deepseek_mtp", + "ernie_mtp"] @config @@ -2029,6 +2056,16 @@ class SpeculativeConfig: "architectures": ["Glm4MoeMTPModel"] }) + if hf_config.model_type == "ernie4_5_moe": + hf_config.model_type = "ernie_mtp" + if hf_config.model_type == "ernie_mtp": + 
n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["ErnieMTPModel"] + }) + return hf_config + return hf_config def __post_init__(self): @@ -2047,8 +2084,8 @@ class SpeculativeConfig: if self.target_model_config and \ (self.target_model_config.hf_text_config.model_type \ == "deepseek_v3" or - self.target_model_config.hf_text_config.model_type \ - == "mimo"): + self.target_model_config.hf_text_config.model_type in + ("mimo","ernie4_5_moe")): # use the draft model from the same model: self.model = self.target_model_config.model elif self.method in ("ngram", "[ngram]"): @@ -2146,6 +2183,15 @@ class SpeculativeConfig: "one layer. Might need some code changes " \ "to support multiple layers." ) + elif (self.draft_model_config.hf_config.model_type == + "ernie_mtp"): + self.method = "ernie_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Ernie MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) else: self.method = "draft_model" raise NotImplementedError( @@ -2361,7 +2407,7 @@ class SpeculativeConfig: return self.num_speculative_tokens def use_eagle(self) -> bool: - return self.method in ("eagle", "eagle3", "deepseek_mtp") + return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp") def __repr__(self) -> str: method = self.method @@ -2500,7 +2546,7 @@ class MultiModalConfig: `{"num_crops": 4}`. """ - mm_processor_cache_gb: int = 4 + mm_processor_cache_gb: float = 4 """ The size (in GiB) of the multi-modal processor cache, which is used to @@ -2511,6 +2557,22 @@ class MultiModalConfig: Set to `0` to disable this cache completely (not recommended). """ + mm_encoder_tp_mode: MMEncoderTPMode = "weights" + """ + Indicates how to optimize multi-modal encoder inference using + tensor parallelism (TP). + + - `"weights"`: Within the same vLLM engine, split the weights of + each layer across TP ranks. 
(default TP behavior) + - `"data"`: Within the same vLLM engine, split the batched input data + across TP ranks to process the data in parallel, while hosting + the full weights on each TP rank. + This batch-level DP is not to be confused with API request-level + DP (which is controlled by `--data-parallel-size`). + This is only supported on a per-model basis and falls back to + `"weights"` if the encoder does not support DP. + """ + interleave_mm_strings: bool = False """ Enable fully interleaved support for multimodal prompts. @@ -2518,7 +2580,7 @@ class MultiModalConfig: skip_mm_profiling: bool = False """ - When enabled, skips multimodal memory profiling and only profiles with + When enabled, skips multimodal memory profiling and only profiles with language backbone model during engine initialization. This reduces engine startup time but shifts the responsibility to users for @@ -2581,24 +2643,24 @@ class PoolerConfig: ## for embeddings models normalize: Optional[bool] = None """ - Whether to normalize the embeddings outputs. + Whether to normalize the embeddings outputs. """ dimensions: Optional[int] = None """ - Reduce the dimensions of embeddings if model + Reduce the dimensions of embeddings if model support matryoshka representation. """ ## for classification models activation: Optional[bool] = None """ - Whether to apply activation function to the classification outputs. + Whether to apply activation function to the classification outputs. """ ## for reward models softmax: Optional[bool] = None """ - Whether to apply softmax to the reward outputs. + Whether to apply softmax to the reward outputs. """ step_tag_id: Optional[int] = None """ @@ -2624,9 +2686,9 @@ class PoolerConfig: max_embed_len: Optional[int] = None """ - Maximum input length allowed for embedding generation. When set, allows + Maximum input length allowed for embedding generation. When set, allows inputs longer than max_embed_len to be accepted for embedding models. 
- This parameter enables accepting long inputs without requiring + This parameter enables accepting long inputs without requiring VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds max_embed_len, it will be handled according to the original max_model_len validation logic. Defaults to None (i.e. set to max_model_len). @@ -2980,7 +3042,8 @@ def get_served_model_name(model: str, return served_model_name -GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines"] +GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines", + "lm-format-enforcer"] @config @@ -3535,15 +3598,6 @@ class VllmConfig: # in V0 means the compilation level wins out. self.compilation_config.level = CompilationLevel.NO_COMPILATION - # if cudagraph_mode is not explicitly set by users, set default value - if self.compilation_config.cudagraph_mode is None: - if envs.VLLM_USE_V1 and self.compilation_config.level \ - == CompilationLevel.PIECEWISE: - self.compilation_config.cudagraph_mode = \ - CUDAGraphMode.PIECEWISE - else: - self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE - # async tp is built on top of sequence parallelism # and requires it to be enabled. 
if self.compilation_config.pass_config.enable_async_tp: @@ -3552,14 +3606,28 @@ class VllmConfig: if self.compilation_config.pass_config.enable_sequence_parallelism: self.compilation_config.custom_ops.append("+rms_norm") - # disable cudagraph when enforce eager execution - if self.model_config is not None and self.model_config.enforce_eager: - logger.info("Cudagraph is disabled under eager mode") - self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE - elif envs.VLLM_USE_V1: - self.compilation_config.cudagraph_num_of_warmups = 1 + if current_platform.is_cuda_alike() or current_platform.is_xpu(): + # if cudagraph_mode is not explicitly set by users, set default + # value + if self.compilation_config.cudagraph_mode is None: + if envs.VLLM_USE_V1 and self.compilation_config.level \ + == CompilationLevel.PIECEWISE: + self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.PIECEWISE + else: + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE - self._set_cudagraph_sizes() + # disable cudagraph when enforce eager execution + if self.model_config is not None and \ + self.model_config.enforce_eager: + logger.info("Cudagraph is disabled under eager mode") + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + elif envs.VLLM_USE_V1: + self.compilation_config.cudagraph_num_of_warmups = 1 + + self._set_cudagraph_sizes() + else: + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE if self.cache_config.cpu_offload_gb > 0 and \ self.compilation_config.level != CompilationLevel.NO_COMPILATION \ @@ -3595,9 +3663,6 @@ class VllmConfig: logger.info(reason) self.scheduler_config.chunked_prefill_enabled = False self.scheduler_config.long_prefill_token_threshold = 0 - self.scheduler_config.max_num_batched_tokens = max( - self.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS) if self.cache_config is not None: self.cache_config.enable_prefix_caching = False @@ -3618,7 +3683,7 @@ class VllmConfig: 
current_platform.check_and_update_config(self) # final check of cudagraph mode after platform-specific update - if envs.VLLM_USE_V1: + if envs.VLLM_USE_V1 and current_platform.is_cuda_alike(): if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \ and self.model_config is not None and \ not self.model_config.disable_cascade_attn: diff --git a/vllm/config/cache.py b/vllm/config/cache.py index ae11dec3ca5e2..a9550d4390ad6 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -116,7 +116,7 @@ class CacheConfig: In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254), some layers can skip tokens corresponding to prefill. This flag enables attention metadata for eligible layers to be overriden with metadata - necessary for implementating this optimization in some models (e.g. Gemma3n) + necessary for implementing this optimization in some models (e.g. Gemma3n) """ def compute_hash(self) -> str: diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 56a2183f8e2c1..56aa00a30d3ae 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -225,7 +225,8 @@ class CompilationConfig: # CudaGraph compilation cudagraph_mode: Optional[CUDAGraphMode] = None """ - The mode of the cudagraph. + The mode of the cudagraph: + - NONE, no cudagraph capture. - PIECEWISE. (v1 default) - FULL. 
@@ -336,6 +337,8 @@ class CompilationConfig: "vllm.unified_attention", "vllm.unified_attention_with_output", "vllm.mamba_mixer2", + "vllm.mamba_mixer", + "vllm.short_conv", ] def compute_hash(self) -> str: @@ -382,13 +385,10 @@ class CompilationConfig: if pass_config_exclude: exclude["pass_config"] = pass_config_exclude - # The cast to string is necessary because Pydantic is mocked in docs - # builds and sphinx-argparse doesn't know the return type of decode() - return str( - TypeAdapter(CompilationConfig).dump_json( - self, - exclude=exclude, # type: ignore[arg-type] - exclude_unset=True).decode()) + return TypeAdapter(CompilationConfig).dump_json( + self, + exclude=exclude, # type: ignore[arg-type] + exclude_unset=True).decode() __str__ = __repr__ diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index bac1e63800d7b..9ea883d4a03cd 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -15,7 +15,7 @@ import vllm.envs as envs from vllm.config.utils import config from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import cuda_device_count_stateless, get_open_port +from vllm.utils import cuda_device_count_stateless, get_open_ports_list if TYPE_CHECKING: from ray.runtime_env import RuntimeEnv @@ -32,6 +32,31 @@ logger = init_logger(__name__) DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] +@config +@dataclass +class EPLBConfig: + """Configuration for Expert Parallel Load Balancing (EP).""" + + window_size: int = 1000 + """Window size for expert load recording.""" + step_interval: int = 3000 + """ + Interval for rearranging experts in expert parallelism. + + Note that if this is greater than the EPLB window size, only the metrics + of the last `lb_window_size` steps will be used for rearranging experts. 
+ """ + + num_redundant_experts: int = 0 + """Number of redundant experts to use for expert parallelism.""" + + log_balancedness: bool = False + """ + Log the balancedness each step of expert parallelism. + This is turned off by default since it will cause communication overhead. + """ + + @config @dataclass class ParallelConfig: @@ -75,22 +100,24 @@ class ParallelConfig: """Use expert parallelism instead of tensor parallelism for MoE layers.""" enable_eplb: bool = False """Enable expert parallelism load balancing for MoE layers.""" - num_redundant_experts: int = 0 - """Number of redundant experts to use for expert parallelism.""" - eplb_window_size: int = 1000 - """Window size for expert load recording.""" - eplb_step_interval: int = 3000 - """ - Interval for rearranging experts in expert parallelism. - - Note that if this is greater than the EPLB window size, only the metrics - of the last `eplb_window_size` steps will be used for rearranging experts. - """ - eplb_log_balancedness: bool = False - """ - Log the balancedness each step of expert parallelism. - This is turned off by default since it will cause communication overhead. - """ + eplb_config: EPLBConfig = field(default_factory=EPLBConfig) + """Expert parallelism configuration.""" + num_redundant_experts: Optional[int] = None + """`num_redundant_experts` is deprecated and has been replaced with + `eplb_config.num_redundant_experts`. This will be removed in v0.12.0. + Please use `eplb_config.num_redundant_experts` instead.""" + eplb_window_size: Optional[int] = None + """`eplb_window_size` is deprecated and has been replaced with + `eplb_config.window_size`. This will be removed in v0.12.0. + Please use `eplb_config.window_size` instead.""" + eplb_step_interval: Optional[int] = None + """`eplb_step_interval` is deprecated and has been replaced with + `eplb_config.step_interval`. This will be removed in v0.12.0. 
+ Please use `eplb_config.step_interval` instead.""" + eplb_log_balancedness: Optional[bool] = None + """`eplb_log_balancedness` is deprecated and has been replaced with + `eplb_config.log_balancedness`. This will be removed in v0.12.0. + Please use `eplb_config.log_balancedness` instead.""" max_parallel_loading_workers: Optional[int] = None """Maximum number of parallel loading workers when loading model @@ -109,7 +136,8 @@ class ParallelConfig: placement_group: Optional[PlacementGroup] = None """ray distributed model workers placement group.""" - distributed_executor_backend: Optional[Union[DistributedExecutorBackend, + distributed_executor_backend: Optional[Union[str, + DistributedExecutorBackend, type[ExecutorBase]]] = None """Backend to use for distributed model workers, either "ray" or "mp" (multiprocessing). If the product @@ -137,9 +165,10 @@ class ParallelConfig: rank: int = 0 """Global rank in distributed setup.""" - enable_multimodal_encoder_data_parallel: bool = False - """ Use data parallelism instead of tensor parallelism for vision encoder. - Only support LLama4 for now""" + _data_parallel_master_port_list: list[int] = field(default_factory=list) + """List of open port auto-queried for data parallel messaging. + Set to be private as it's not intended to be configured by users. + """ @property def world_size_across_dp(self) -> int: @@ -153,11 +182,15 @@ class ParallelConfig: processes that is related to data parallelism, e.g. both in the worker and in the engine, which can live in different processes. To avoid port conflicts, we - increment the port number each time we need to initialize a - new process group related to data parallelism. + pop a new port from the prepared port list each time we need to + initialize a new process group related to data parallelism. 
""" - answer = self.data_parallel_master_port - self.data_parallel_master_port += 1 + if self._data_parallel_master_port_list: + answer = self._data_parallel_master_port_list.pop() + else: + answer = self.data_parallel_master_port + self.data_parallel_master_port += 1 + return answer def stateless_init_dp_group(self) -> ProcessGroup: @@ -241,6 +274,38 @@ class ParallelConfig: return hashlib.sha256(str(factors).encode()).hexdigest() def __post_init__(self) -> None: + # Forward deprecated fields to their new location + if self.num_redundant_experts is not None: + self.eplb_config.num_redundant_experts = ( + self.num_redundant_experts) + logger.warning_once( + "num_redundant_experts is deprecated and has been replaced " + "with eplb_config.num_redundant_experts. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + if self.eplb_window_size is not None: + self.eplb_config.window_size = self.eplb_window_size + logger.warning_once( + "eplb_window_size is deprecated and has been replaced " + "with eplb_config.window_size. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + if self.eplb_step_interval is not None: + self.eplb_config.step_interval = self.eplb_step_interval + logger.warning_once( + "eplb_step_interval is deprecated and has been replaced " + "with eplb_config.step_interval. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + if self.eplb_log_balancedness is not None: + self.eplb_config.log_balancedness = self.eplb_log_balancedness + logger.warning_once( + "eplb_log_balancedness is deprecated and has been replaced " + "with eplb_config.log_balancedness. This will be removed " + "in v0.12.0. 
Changing this field after initialization will " + "have no effect.") + + # Continue with the rest of the initialization self.world_size = self.pipeline_parallel_size * \ self.tensor_parallel_size @@ -251,7 +316,10 @@ class ParallelConfig: if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: # Data parallel was specified in the engine args. - self.data_parallel_master_port = get_open_port() + if not self._data_parallel_master_port_list: + self._data_parallel_master_port_list = get_open_ports_list(5) + self.data_parallel_master_port = \ + self._data_parallel_master_port_list.pop() if not (0 <= self.data_parallel_rank < self.data_parallel_size): raise ValueError( @@ -279,10 +347,10 @@ class ParallelConfig: raise ValueError( "Expert parallelism load balancing is only supported on " "CUDA devices now.") - if self.num_redundant_experts < 0: + if self.eplb_config.num_redundant_experts < 0: raise ValueError( "num_redundant_experts must be non-negative, but got " - f"{self.num_redundant_experts}.") + f"{self.eplb_config.num_redundant_experts}.") if not self.enable_expert_parallel: raise ValueError( "enable_expert_parallel must be True to use EPLB.") @@ -293,10 +361,10 @@ class ParallelConfig: f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}." ) else: - if self.num_redundant_experts != 0: + if self.eplb_config.num_redundant_experts != 0: raise ValueError( "num_redundant_experts should be used with EPLB." - f"{self.num_redundant_experts}.") + f"{self.eplb_config.num_redundant_experts}.") if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. 
@@ -342,23 +410,22 @@ class ParallelConfig: def use_ray(self) -> bool: return self.distributed_executor_backend == "ray" or ( isinstance(self.distributed_executor_backend, type) - and self.distributed_executor_backend.uses_ray) + and getattr(self.distributed_executor_backend, "uses_ray", False)) @model_validator(mode='after') def _verify_args(self) -> Self: # Lazy import to avoid circular import from vllm.executor.executor_base import ExecutorBase from vllm.platforms import current_platform - if self.distributed_executor_backend not in ( - "ray", "mp", "uni", - "external_launcher", None) and not (isinstance( + if self.distributed_executor_backend is not None and not isinstance( + self.distributed_executor_backend, str) and not (isinstance( self.distributed_executor_backend, type) and issubclass( self.distributed_executor_backend, ExecutorBase)): raise ValueError( "Unrecognized distributed executor backend " f"{self.distributed_executor_backend}. Supported " - "values are 'ray', 'mp' 'uni', 'external_launcher' or" - " custom ExecutorBase subclass.") + "values are 'ray', 'mp' 'uni', 'external_launcher', " + " custom ExecutorBase subclass or its import path.") if self.use_ray: from vllm.executor import ray_utils ray_utils.assert_ray_available() diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 4ec5a775f465c..cbfa4d7ff3c4c 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -352,7 +352,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): with num_lookahead_slots. Args: - sequence_group (SequenceGroup): The sequence group to swap in. + seq_group (SequenceGroup): The sequence group to swap in. num_lookahead_slots (int): Number of lookahead slots used in speculative decoding, default to 0. @@ -405,8 +405,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): Args: seq_group (SequenceGroup): The sequence group to swap out. 
- num_lookahead_slots (int): Number of lookahead slots used in - speculative decoding, default to 0. Returns: bool: Whether it's possible to swap out current sequence group. @@ -420,7 +418,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): swapping out the given sequence_group with num_lookahead_slots. Args: - sequence_group (SequenceGroup): The sequence group to swap out. + seq_group (SequenceGroup): The sequence group to swap out. Returns: List[Tuple[int, int]]: The mapping of swapping block from @@ -473,7 +471,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): on to the 'device'. Args: - sequence_group (SequenceGroup): The sequence group to swap in/out. + seq_group (SequenceGroup): The sequence group to swap in/out. device (Device): device to swap the 'seq_group' on. status (SequenceStatus): The status of sequence which is needed for action. RUNNING for swap out and SWAPPED for swap in diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 942e866ed97ee..7963fb15c4191 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -152,8 +152,13 @@ class CuMemAllocator: self.pointer_to_data: dict[int, AllocationData] = {} self.current_tag: str = CuMemAllocator.default_tag self.allocator_and_pools: dict[str, Any] = {} + # Creating strong references to the two callbacks here to prevent + # these ephemeral bound-method objects being garbage collected. 
+ # See discussions in https://github.com/vllm-project/vllm/pull/22724 + self.python_malloc_callback = self._python_malloc_callback + self.python_free_callback = self._python_free_callback - def python_malloc_callback(self, allocation_handle: HandleType) -> None: + def _python_malloc_callback(self, allocation_handle: HandleType) -> None: """ Internal method to store the allocation data when memory is allocated in the memory pool.""" @@ -162,7 +167,7 @@ class CuMemAllocator: allocation_handle, self.current_tag) return - def python_free_callback(self, ptr: int) -> HandleType: + def _python_free_callback(self, ptr: int) -> HandleType: """ Internal method to look up the allocation data when memory is freed in the memory pool.""" @@ -212,9 +217,9 @@ class CuMemAllocator: def wake_up(self, tags: Optional[list[str]] = None) -> None: """ Wake up the allocator from sleep mode. - All data that is previously offloaded will be loaded back to GPU + All data that is previously offloaded will be loaded back to GPU memory, and the rest of the data will have empty memory. - + :param tags: The tags of the memory allocation that will be loaded back to GPU memory. If None, all memory allocation will be loaded back to GPU memory. 
diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py similarity index 93% rename from vllm/distributed/device_communicators/custom_all_reduce_utils.py rename to vllm/distributed/device_communicators/all_reduce_utils.py index 7c6001e870392..5c64e7d5c4ba3 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -23,6 +23,39 @@ from vllm.utils import (cuda_device_count_stateless, logger = init_logger(__name__) +MiB = 1024 * 1024 +# Max size for each world size in case symmetric memory is available +# For different SM architectures +CUSTOM_ALL_REDUCE_MAX_SIZES = { + "9.0": { + 2: 64 * MiB, # 64 MB + 4: 32 * MiB, # 32 MB + 6: MiB // 2, # 512 KB + 8: MiB // 4, # 256 KB + }, + "10.0": { + 2: 2 * MiB, # 2 MB + 4: 2 * MiB, # 2 MB + 6: 2 * MiB, # 2 MB + 8: 2 * MiB, # 2 MB + } +} + +SYMM_MEM_ALL_REDUCE_MAX_SIZES = { + "9.0": { + 2: 64 * MiB, # 64 MB + 4: 32 * MiB, # 32 MB + 6: 64 * MiB, # 64 MB + 8: 64 * MiB, # 64 MB + }, + "10.0": { + 2: 8 * MiB, # 8 MB + 4: 32 * MiB, # 32 MB + 6: 128 * MiB, # 128 MB + 8: 128 * MiB, # 128 MB + } +} + def producer(batch_src: Sequence[int], producer_queue, diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 127a340fc6c6d..9e5aa4e4c2a89 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -105,7 +105,8 @@ class DeviceCommunicatorBase: # we initialize the all2all manager used in expert parallel. 
use_ep = config.parallel_config.data_parallel_size > 1 - self.use_all2all = "ep" in unique_name and use_ep + self.is_ep_communicator = "ep" in unique_name + self.use_all2all = self.is_ep_communicator and use_ep self.all2all_manager: Optional[All2AllManagerBase] = None def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: @@ -246,7 +247,7 @@ class DeviceCommunicatorBase: """ Prepare the communication buffer for the model. """ - if not self.use_all2all: + if not self.is_ep_communicator: return moe_modules = [ @@ -254,7 +255,7 @@ class DeviceCommunicatorBase: if module.__class__.__name__ == "FusedMoE" ] for module in moe_modules: - module.quant_method.init_prepare_finalize(module.moe_config) + module.quant_method.init_prepare_finalize() def dispatch( self, hidden_states: torch.Tensor, diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 66d4940c9cec5..eef3f9f75f9f1 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -44,6 +44,8 @@ class CudaCommunicator(DeviceCommunicatorBase): PyNcclCommunicator) from vllm.distributed.device_communicators.quick_all_reduce import ( QuickAllReduce) + from vllm.distributed.device_communicators.symm_mem import ( + SymmMemCommunicator) self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: @@ -54,6 +56,7 @@ class CudaCommunicator(DeviceCommunicatorBase): self.ca_comm: Optional[CustomAllreduce] = None self.qr_comm: Optional[QuickAllReduce] = None + self.symm_mem_comm: Optional[SymmMemCommunicator] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( @@ -69,6 +72,12 @@ class CudaCommunicator(DeviceCommunicatorBase): # currently be an MI300 series. 
self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device) + if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda(): + self.symm_mem_comm = SymmMemCommunicator( + group=self.cpu_group, + device=self.device, + ) + if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND if all2all_backend == "naive": @@ -105,6 +114,12 @@ class CudaCommunicator(DeviceCommunicatorBase): out = ca_comm.custom_all_reduce(input_) assert out is not None return out + symm_mem_comm = self.symm_mem_comm + if symm_mem_comm is not None and \ + symm_mem_comm.should_use_symm_mem(input_): + out = symm_mem_comm.all_reduce(input_) + assert out is not None + return out pynccl_comm = self.pynccl_comm assert pynccl_comm is not None out = pynccl_comm.all_reduce(input_) @@ -137,7 +152,7 @@ class CudaCommunicator(DeviceCommunicatorBase): dtype=input_tensor.dtype, device=input_tensor.device) - pynccl_comm.reduce_scatter(output, input_) + pynccl_comm.reduce_scatter(output, input_tensor) # Reshape before returning return output.movedim(0, dim).contiguous() @@ -171,9 +186,9 @@ class CudaCommunicator(DeviceCommunicatorBase): device=input_tensor.device) if sizes is not None: - pynccl_comm.reduce_scatterv(output, input_, sizes=sizes) + pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes) else: - pynccl_comm.reduce_scatter(output, input_) + pynccl_comm.reduce_scatter(output, input_tensor) # Reshape before returning return output.movedim(0, dim).contiguous() diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 8dfb7959a510d..80aca81234eb0 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -10,8 +10,8 @@ from torch.distributed import ProcessGroup import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.distributed.device_communicators.custom_all_reduce_utils import ( - 
gpu_p2p_access_check) +from vllm.distributed.device_communicators.all_reduce_utils import ( + CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check) from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform @@ -109,7 +109,13 @@ class CustomAllreduce: # now `device` is a `torch.device` object assert isinstance(device, torch.device) self.device = device - + device_capability = current_platform.get_device_capability( + ).as_version_str() + if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM + and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES): + max_size = min( + CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], + max_size) cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES if cuda_visible_devices: device_ids = list(map(int, cuda_visible_devices.split(","))) diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py new file mode 100644 index 0000000000000..d907e1b833d04 --- /dev/null +++ b/vllm/distributed/device_communicators/symm_mem.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from vllm.distributed.device_communicators.all_reduce_utils import ( + SYMM_MEM_ALL_REDUCE_MAX_SIZES) +from vllm.logger import init_logger +from vllm.platforms import current_platform + +try: + import torch.distributed._symmetric_memory as torch_symm_mem + + symm_mem_available = True +except ImportError: + symm_mem_available = False + +logger = init_logger(__name__) + + +class SymmMemCommunicator: + _WORLD_SIZES_MULTIMEM = { + "9.0": [4, 6, 8], + "10.0": [6, 8], + } + + def __init__(self, group: ProcessGroup, device: Union[int, str, + torch.device]): + self.disabled = True + + if not symm_mem_available: + return + + 
if not current_platform.is_cuda(): + logger.warning("SymmMemCommunicator: symmetric " + "memory is not available.") + return + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + torch.cuda.set_device(device) + self.dtype = torch.bfloat16 + self.device = device + self.group = group + self.world_size = dist.get_world_size(self.group) + self.device_capability = current_platform.get_device_capability( + ).as_version_str() + if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES: + logger.warning( + "SymmMemCommunicator: Device capability %s not supported, " + "communicator is not available.", + self.device_capability, + ) + return + if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[ + self.device_capability]: + logger.warning( + "SymmMemCommunicator: World size %d not supported, " + "communicator is not available.", + self.world_size, + ) + return + self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ + self.world_size] + self.buffer = torch_symm_mem.empty( + self.max_size // self.dtype.itemsize, + device=self.device, + dtype=self.dtype, + ) + handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name) + if handle.multicast_ptr == 0: + logger.warning("SymmMemCommunicator: symmetric memory " + "multicast operations are not supported.") + return + self.disabled = False + + def should_use_symm_mem(self, inp: torch.Tensor): + if self.disabled: + return False + if inp.dtype != self.dtype: + return False + inp_size = inp.numel() * inp.element_size() + if inp_size % 4 != 0: + return False + return inp_size < self.max_size + + def all_reduce( + self, + inp: torch.Tensor, + *, + out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: + if not self.should_use_symm_mem(inp): + return None + if out is None: + out = torch.empty_like(inp) + self.buffer[:inp.numel()].copy_(inp.view(-1)) + if self.world_size in self._WORLD_SIZES_MULTIMEM[ + 
self.device_capability]: + torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()], + "sum", + self.group.group_name) + else: + torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()], + "sum", + self.group.group_name) + out.copy_(self.buffer[:inp.numel()].view(out.shape)) + return out diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index c60a7a7eb25cf..942dd67f065dc 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.platforms.tpu import USE_TPU_COMMONS from .base_device_communicator import DeviceCommunicatorBase @@ -18,16 +19,17 @@ USE_RAY = parallel_config = get_current_vllm_config( logger = init_logger(__name__) -if current_platform.is_tpu(): - import torch_xla - import torch_xla.core.xla_model as xm - import torch_xla.runtime as xr - from torch_xla._internal import pjrt - from torch_xla.distributed.xla_multiprocessing import ( - create_optimized_replica_groups) - - if USE_RAY: - from vllm.executor import ray_utils +if not USE_TPU_COMMONS: + logger.info("tpu_commons not found, using vLLM's TpuCommunicator") + if current_platform.is_tpu(): + import torch_xla + import torch_xla.core.xla_model as xm + import torch_xla.runtime as xr + from torch_xla._internal import pjrt + from torch_xla.distributed.xla_multiprocessing import ( + create_optimized_replica_groups) + if USE_RAY: + from vllm.executor import ray_utils class TpuCommunicator(DeviceCommunicatorBase): @@ -94,10 +96,7 @@ class TpuCommunicator(DeviceCommunicatorBase): return xm.all_gather(input_, dim=dim) -try: +if USE_TPU_COMMONS: from tpu_commons.distributed.device_communicators import ( TpuCommunicator as 
TpuCommonsCommunicator) TpuCommunicator = TpuCommonsCommunicator # type: ignore -except ImportError: - logger.info("tpu_commons not found, using vLLM's TpuCommunicator") - pass diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 979f2a06cec9f..042acf40d67c2 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -244,7 +244,7 @@ class EplbState: dtype=torch.int32, device=device, ) - expert_load_window_size = parallel_config.eplb_window_size + expert_load_window_size = parallel_config.eplb_config.window_size expert_load_window = torch.zeros( (expert_load_window_size, model.num_moe_layers, model.num_physical_experts), @@ -253,7 +253,7 @@ class EplbState: ) # Set the initial progress of rearrangement to 3/4 - eplb_step_interval = parallel_config.eplb_step_interval + eplb_step_interval = parallel_config.eplb_config.step_interval expert_rearrangement_step = max( 0, eplb_step_interval - eplb_step_interval // 4) diff --git a/vllm/distributed/kv_transfer/README.md b/vllm/distributed/kv_transfer/README.md index 349d3dfbd84fc..39377aabcce3a 100644 --- a/vllm/distributed/kv_transfer/README.md +++ b/vllm/distributed/kv_transfer/README.md @@ -2,7 +2,7 @@ # Distributed KV cache transfer This folder implements distributed KV cache transfer across vLLM instances. -Currently the main usecase is for disaggregated prefilling. +Currently the main use case is for disaggregated prefilling. ## Abstractions @@ -14,7 +14,7 @@ The KV cache transfer contains three layer of abstractions: Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. 
This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer. -NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed +NOTE: KV pipe layer is bypassable: you can skip this layer if your distributed communication service already supports key-value-based lookup (like redis or RDMA database). diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 07fcdecac6276..5601ee74be110 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -131,8 +131,8 @@ class KVConnectorBase_V1(ABC): Initialize with the KV caches. Useful for pre-registering the KV Caches in the KVConnector (e.g. for NIXL). - Args: kv_caches: - dictionary of layer names, kv cache + Args: + kv_caches: dictionary of layer names, kv cache """ return diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 4f51229ffbd26..6608d2a4a9e09 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -686,9 +686,6 @@ class NixlConnectorWorker: def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" - _, first_kv_cache = next(iter(kv_caches.items())) - kv_elem_size = first_kv_cache.element_size() - if self.use_host_buffer: self.initialize_host_xfer_buffer(kv_caches=kv_caches) assert len(self.host_xfer_buffers) == len(kv_caches), ( @@ -701,66 +698,16 @@ class NixlConnectorWorker: "host_xfer_buffer should not be initialized when " f"kv_buffer_device is {self.kv_buffer_device}") - # TODO(tms): Find a more robust way to detect and handle MLA - # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected - # KV memory layout is HND, as opposed 
to the default NHD. Note that it - # will only affects the strides. For MLA instead, we make require no - # such thing and resort to the standard layout. - use_mla = len(first_kv_cache.shape) == 3 - if self.device_type == "tpu": - assert not use_mla, f"{self.kv_buffer_device} does not support MLA." - assert self._use_pallas_v1, f"attn backend: {self.backend_name}" - # tpu (v1) kv shape per layer: - # (num_blocks, block_size, num_kv_heads * 2, head_size) - self.num_blocks = first_kv_cache.shape[0] - block_rank = 3 # [block_size, kv_heads, head_dim] - block_shape = first_kv_cache.shape[-block_rank:] - block_size, n_kv_heads_x_2, head_dim = block_shape - self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim - elif self.device_type == "cuda": - assert use_mla == self.use_mla - # TODO (NickLucche) not compatible with hybrid allocator. - # Enforce check once it goes live, as a single kv layout - # is expected for xfers. - if use_mla: - # MLA case. - self.num_blocks = first_kv_cache.shape[0] - block_rank = 2 # [block_size, latent_dim] - block_shape = first_kv_cache.shape[-block_rank:] - block_size, kv_latent_dim = block_shape - self.slot_size_bytes = kv_elem_size * kv_latent_dim - else: - # [2 (k and v), num_blocks, ...] - if self._use_flashinfer: - # FlashInfer swaps 2<->num_blocks dimensions. - self.num_blocks = first_kv_cache.shape[0] - block_rank = 4 # [2, block_size, kv_heads, head_dim] - else: - self.num_blocks = first_kv_cache.shape[1] - block_rank = 3 # [block_size, kv_heads, head_dim] - block_shape = first_kv_cache.shape[-block_rank:] - block_size, n_kv_heads, head_dim = block_shape[-3:] - # head size in bytes. 
- self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim - assert block_size == self.block_size - else: - raise RuntimeError( - f"{self.device_type} ({self.backend_name}) is not supported.") - - # TODO(tms): self.block_len needs to be per-layer for sliding window, - # hybrid attn, etc - # block size in bytes - self.block_len = kv_elem_size * math.prod(block_shape) logger.info( "Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, " - "use_host_buffer: %s, num_blocks: %s, block_shape: %s, " - "per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device, - self.use_host_buffer, self.num_blocks, block_shape, - first_kv_cache.shape) - self.dst_num_blocks[self.engine_id] = self.num_blocks - self.device_kv_caches = kv_caches - kv_caches_base_addr = [] + "use_host_buffer: %s", self.use_mla, self.kv_buffer_device, + self.use_host_buffer) + caches_data = [] + # With hybrid allocator, layers can share a kv cache tensor + seen_base_addresses = [] + xfer_buffers = (self.host_xfer_buffers + if self.use_host_buffer else kv_caches) # Note(tms): I modified this from the original region setup code. # K and V are now in different regions. Advantage is that we can @@ -770,42 +717,35 @@ class NixlConnectorWorker: # (roughly 8KB vs 5KB). # Conversely for FlashInfer, K and V are transferred in the same tensor # to better exploit the memory layout (ie num_blocks is the first dim). 
- for cache_or_caches in xfer_buffers.values(): - # Normalize to always be a list of caches - cache_list = [cache_or_caches] if use_mla \ - or self._use_pallas_v1 or self._use_flashinfer \ - else cache_or_caches + split_k_and_v = not (self.use_mla or self._use_pallas_v1 + or self._use_flashinfer) + tensor_size_bytes = None + for layer_name, cache_or_caches in xfer_buffers.items(): + cache_list = cache_or_caches if split_k_and_v else [ + cache_or_caches + ] + for cache in cache_list: base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len - # NOTE: use tp_rank for device_id since multi-node TP - # is rarely used. - caches_data.append((base_addr, region_len, self.tp_rank, "")) - kv_caches_base_addr.append(base_addr) - self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + if base_addr in seen_base_addresses: + continue + + seen_base_addresses.append(base_addr) + curr_tensor_size_bytes = cache.numel() * cache.element_size() + + if tensor_size_bytes is None: + tensor_size_bytes = curr_tensor_size_bytes + self.num_blocks = cache.shape[0] + + assert tensor_size_bytes == curr_tensor_size_bytes, \ + "All kv cache tensors must have the same size" + caches_data.append( + (base_addr, tensor_size_bytes, self.tp_rank, "")) + + self.kv_caches_base_addr[self.engine_id] = seen_base_addresses self.num_regions = len(caches_data) self.num_layers = len(xfer_buffers.keys()) - # TODO(mgoin): remove this once we have hybrid memory allocator - # Optimization for models with local attention (Llama 4) - if self.vllm_config.model_config.hf_config.model_type == "llama4": - from transformers import Llama4TextConfig - assert isinstance(self.vllm_config.model_config.hf_text_config, - Llama4TextConfig) - llama4_config = self.vllm_config.model_config.hf_text_config - no_rope_layers = llama4_config.no_rope_layers - chunk_size = llama4_config.attention_chunk_size - chunk_block_size = math.ceil(chunk_size / self.block_size) - for layer_idx in range(self.num_layers): - # 
no_rope_layers[layer_idx] == 0 means NoPE (global) - # Any other value means RoPE (local chunked) - is_local_attention = no_rope_layers[layer_idx] != 0 - block_window = chunk_block_size if is_local_attention else None - self.block_window_per_layer.append(block_window) - logger.debug("Llama 4 block window per layer mapping: %s", - self.block_window_per_layer) - assert len(self.block_window_per_layer) == self.num_layers - descs = self.nixl_wrapper.get_reg_descs(caches_data, self.nixl_memory_type) logger.debug("Registering descs: %s", caches_data) @@ -813,9 +753,20 @@ class NixlConnectorWorker: logger.debug("Done registering descs") self._registered_descs.append(descs) + assert tensor_size_bytes is not None + assert self.num_blocks != 0 + assert tensor_size_bytes % self.num_blocks == 0 + self.block_len = tensor_size_bytes // self.num_blocks + self.slot_size_bytes = self.block_len // self.block_size + if self._use_flashinfer: + assert self.slot_size_bytes % 2 == 0 + self.slot_size_bytes /= 2 + self.device_kv_caches = kv_caches + self.dst_num_blocks[self.engine_id] = self.num_blocks + # Register local/src descr for NIXL xfer. blocks_data = [] - for base_addr in self.kv_caches_base_addr[self.engine_id]: + for base_addr in seen_base_addresses: # NOTE With heter-TP, more blocks are prepared than what are # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We # could create fewer, but then _get_block_descs_ids needs to @@ -836,6 +787,26 @@ class NixlConnectorWorker: self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( "NIXL_INIT_AGENT", descs) + # TODO(mgoin): Hybrid memory allocator is currently diabled for + # models with local attention (Llama 4). Can remove this once enabled. 
+ if self.vllm_config.model_config.hf_config.model_type == "llama4": + from transformers import Llama4TextConfig + assert isinstance(self.vllm_config.model_config.hf_text_config, + Llama4TextConfig) + llama4_config = self.vllm_config.model_config.hf_text_config + no_rope_layers = llama4_config.no_rope_layers + chunk_size = llama4_config.attention_chunk_size + chunk_block_size = math.ceil(chunk_size / self.block_size) + for layer_idx in range(self.num_layers): + # no_rope_layers[layer_idx] == 0 means NoPE (global) + # Any other value means RoPE (local chunked) + is_local_attention = no_rope_layers[layer_idx] != 0 + block_window = chunk_block_size if is_local_attention else None + self.block_window_per_layer.append(block_window) + logger.debug("Llama 4 block window per layer mapping: %s", + self.block_window_per_layer) + assert len(self.block_window_per_layer) == self.num_layers + # After KV Caches registered, listen for new connections. metadata = NixlAgentMetadata( engine_id=self.engine_id, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index 32d0e43d71afe..2485c57d86ecc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -30,27 +30,19 @@ logger = init_logger(__name__) class ReqMeta: # Request Id request_id: str - # Request tokens - token_ids: torch.Tensor - # Slot mappings, should have the same length as token_ids - slot_mapping: torch.Tensor + # Request block ids + block_ids: torch.Tensor + # Request num tokens + num_tokens: int @staticmethod def make_meta(request_id: str, token_ids: list[int], block_ids: list[int], block_size: int) -> "ReqMeta": - valid_num_tokens = len(token_ids) - token_ids_tensor = torch.tensor(token_ids) block_ids_tensor = torch.tensor(block_ids) - num_blocks = block_ids_tensor.shape[0] - block_offsets = torch.arange(0, 
block_size) - slot_mapping = block_offsets.reshape((1, block_size)) + \ - block_ids_tensor.reshape((num_blocks, 1)) * block_size - slot_mapping = slot_mapping.flatten()[:valid_num_tokens] - return ReqMeta( request_id=request_id, - token_ids=token_ids_tensor, - slot_mapping=slot_mapping, + block_ids=block_ids_tensor, + num_tokens=len(token_ids), ) @@ -123,63 +115,58 @@ class P2pNcclConnector(KVConnectorBase_V1): return def inject_kv_into_layer( - dst_kv_cache_layer: torch.Tensor, - src_kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, + layer: torch.Tensor, + kv_cache: torch.Tensor, + block_ids: torch.Tensor, request_id: str, ) -> None: - """Inject the KV cache into the layer. + """ + Inject KV cache data into a given attention layer tensor. + + This function updates `layer` in-place with values from `kv_cache`, + handling different backend layouts: + - MLA (Multi-Linear Attention) or FlashInfer: KV tensors are + indexed along the first dimension. + - FlashAttention: KV tensors are indexed along the second + dimension. + + If the number of provided block IDs does not match the number of KV + blocks, only the overlapping portion is updated, and a warning is + logged. Args: - dst_kv_cache_layer (torch.Tensor): the destination KV cache - layer. In shape [2, num_pages, page_size, xxx] if not - using MLA, [num_pages, page_size, xxx] otherwise. - src_kv_cache (torch.Tensor): the source KV cache. In shape - [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] - otherwise. - slot_mapping (torch.Tensor): the slot mapping. In shape - [num_tokens]. - request_id (str): request id for log + layer (torch.Tensor): The attention layer KV tensor to update. + kv_cache (torch.Tensor): The KV cache tensor to inject. + block_ids (torch.Tensor): Indices of the blocks to update. + request_id (str): Request identifier used for logging. + + Returns: + None. The function modifies `layer` in-place. 
""" - dst_kv_cache_layer_shape = dst_kv_cache_layer.shape - if isinstance(attn_metadata, MLACommonMetadata): - num_pages = dst_kv_cache_layer_shape[0] - page_size = dst_kv_cache_layer_shape[1] - dst_kv_cache_layer = dst_kv_cache_layer.reshape( - num_pages * page_size, -1) - self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache, - 0) - num_token = src_kv_cache.shape[0] - if len(slot_mapping) == num_token: - dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache + if (isinstance(attn_metadata, MLACommonMetadata) + or layer.shape[1] == 2): # MLA or FlashInfer + num_block = kv_cache.shape[0] + self.check_tensors_except_dim(layer, kv_cache, 0) + if len(block_ids) == num_block: + layer[block_ids, ...] = kv_cache else: - dst_kv_cache_layer[slot_mapping[:num_token], - ...] = src_kv_cache + layer[block_ids[:num_block], ...] = kv_cache logger.warning( - "🚧src_kv_cache does not match, num_slot:%d, " - "num_token:%d, request_id:%s", len(slot_mapping), - num_token, request_id) + "🚧kv_cache does not match, block_ids:%d, " + "num_block:%d, request_id:%s", len(block_ids), + num_block, request_id) - dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) - else: - num_pages = dst_kv_cache_layer_shape[1] - page_size = dst_kv_cache_layer_shape[2] - dst_kv_cache_layer = dst_kv_cache_layer.reshape( - 2, num_pages * page_size, -1) - self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache, - 1) - num_token = src_kv_cache.shape[1] - if len(slot_mapping) == num_token: - dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache + elif layer.shape[0] == 2: # FlashAttention + num_block = kv_cache.shape[1] + self.check_tensors_except_dim(layer, kv_cache, 1) + if len(block_ids) == num_block: + layer[:, block_ids, ...] = kv_cache else: - dst_kv_cache_layer[:, slot_mapping[:num_token], - ...] = src_kv_cache + layer[:, block_ids[:num_block], ...] 
= kv_cache logger.warning( - "🚧src_kv_cache does not match, num_slot:%d, " - "num_token:%d, request_id:%s", len(slot_mapping), - num_token, request_id) - - dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + "🚧kv_cache does not match, block_ids:%d, " + "num_block:%d, request_id:%s", len(block_ids), + num_block, request_id) # Get the metadata metadata: KVConnectorMetadata = \ @@ -201,19 +188,17 @@ class P2pNcclConnector(KVConnectorBase_V1): if kv_cache is None: continue - kv_cache_layer = kv_cache[ \ - forward_context.virtual_engine] + layer = kv_cache[forward_context.virtual_engine] kv_cache = self.p2p_nccl_engine.recv_tensor( request.request_id + "#" + layer_name) if kv_cache is None: - logger.warning("🚧src_kv_cache is None, %s", - request.request_id) + logger.warning("🚧kv_cache is None, %s", request.request_id) continue - inject_kv_into_layer(kv_cache_layer, kv_cache, - request.slot_mapping, request.request_id) + inject_kv_into_layer(layer, kv_cache, request.block_ids, + request.request_id) def wait_for_layer_load(self, layer_name: str) -> None: """Blocking until the KV for a specific layer is loaded into vLLM's @@ -245,16 +230,46 @@ class P2pNcclConnector(KVConnectorBase_V1): assert self.p2p_nccl_engine is not None + def extract_kv_from_layer( + layer: torch.Tensor, + block_ids: torch.Tensor, + ) -> torch.Tensor: + """ + Extract KV cache slices from a given attention layer tensor. + + This function handles multiple backend layouts: + - MLA (Multi-Linear Attention) or FlashInfer: KV tensors are + indexed along the first dimension. + - FlashAttention: KV tensors are indexed along the second + dimension. + + Args: + layer (torch.Tensor): The KV cache from the attention layer. + block_ids (torch.Tensor): Indices of blocks to extract. + + Returns: + torch.Tensor: A tensor containing the extracted KV slices. + Returns None if the layout is unsupported. 
+ """ + if (isinstance(attn_metadata, MLACommonMetadata) + or layer.shape[1] == 2): # MLA or FlashInfer + return layer[block_ids, ...] + + if layer.shape[0] == 2: # FlashAttention + return layer[:, block_ids, ...] + + return None + connector_metadata = self._get_connector_metadata() assert isinstance(connector_metadata, P2pNcclConnectorMetadata) for request in connector_metadata.requests: request_id = request.request_id ip, port = self.parse_request_id(request_id, True) remote_address = ip + ":" + str(port + self._rank) - self.p2p_nccl_engine.send_tensor( - request_id + "#" + layer_name, kv_layer, remote_address, - request.slot_mapping, - isinstance(attn_metadata, MLACommonMetadata)) + + kv_cache = extract_kv_from_layer(kv_layer, request.block_ids) + self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, + kv_cache, remote_address) def wait_for_save(self): if self.is_producer: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index b94f2296dcb36..dfd95548c4632 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -62,8 +62,6 @@ class SendQueueItem: tensor_id: str remote_address: str tensor: torch.Tensor - slot_mapping: torch.Tensor - is_mla: bool class P2pNcclEngine: @@ -202,8 +200,6 @@ class P2pNcclEngine: tensor_id: str, tensor: torch.Tensor, remote_address: typing.Optional[str] = None, - slot_mapping: torch.Tensor = None, - is_mla: bool = False, ) -> bool: if remote_address is None: with self.recv_store_cv: @@ -213,9 +209,7 @@ class P2pNcclEngine: item = SendQueueItem(tensor_id=tensor_id, remote_address=remote_address, - tensor=tensor, - slot_mapping=slot_mapping, - is_mla=is_mla) + tensor=tensor) if self.send_type == "PUT": return self.send_sync(item) @@ -433,9 +427,7 @@ class P2pNcclEngine: if item.remote_address not in self.socks: 
self.create_connect(item.remote_address) - with self.send_stream: - tensor = self.extract_kv_from_layer(item.is_mla, item.tensor, - item.slot_mapping) + tensor = item.tensor sock = self.socks[item.remote_address] comm, rank = self.comms[item.remote_address] @@ -548,21 +540,3 @@ class P2pNcclEngine: self._send_thread.join() if self._ping_thread is not None: self._ping_thread.join() - - @staticmethod - def extract_kv_from_layer( - is_mla: bool, - layer: torch.Tensor, - slot_mapping: torch.Tensor, - ) -> torch.Tensor: - """Extract the KV cache from the layer. - Assume the shape of the layer is (2, num_pages, page_size, xxx) - if MLA is not used, and (num_pages, page_size, xxx) otherwise. - """ - if is_mla: - num_pages, page_size = layer.shape[0], layer.shape[1] - return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...] - - num_pages, page_size = layer.shape[1], layer.shape[2] - return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, - ...] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py index 02e3bc6274f60..b775276d4a846 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py @@ -99,8 +99,9 @@ class TensorMemoryPool: addr=self.base_address) self.free_lists[self.max_block_size][ initial_block.addr] = initial_block - logger.debug("TensorMemoryPool, base_address:", self.base_address, - self.base_address % self.max_block_size) + + logger.debug("TensorMemoryPool, base_address:%d, max_block_size:%d", + self.base_address, self.max_block_size) def allocate(self, size: int) -> int: """Allocates a memory block of at least the requested size. 
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f8af6d36e0c06..3399d505e3631 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -8,13 +8,13 @@ import dataclasses import functools import json import sys -import threading from dataclasses import MISSING, dataclass, fields, is_dataclass from itertools import permutations from typing import (TYPE_CHECKING, Annotated, Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, Union, cast, get_args, get_origin) +import huggingface_hub import regex as re import torch from pydantic import TypeAdapter, ValidationError @@ -24,25 +24,26 @@ import vllm.envs as envs from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, ConfigFormat, ConfigType, ConvertOption, DecodingConfig, DetailedTraceModules, Device, - DeviceConfig, DistributedExecutorBackend, + DeviceConfig, DistributedExecutorBackend, EPLBConfig, GuidedDecodingBackend, HfOverrides, KVEventsConfig, KVTransferConfig, LoadConfig, LogprobsMode, - LoRAConfig, MambaDType, ModelConfig, ModelDType, - ModelImpl, MultiModalConfig, ObservabilityConfig, - ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, - RunnerOption, SchedulerConfig, SchedulerPolicy, - SpeculativeConfig, TaskOption, TokenizerMode, - VllmConfig, get_attr_docs, get_field) + LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, + ModelDType, ModelImpl, MultiModalConfig, + ObservabilityConfig, ParallelConfig, PoolerConfig, + PrefixCachingHashAlgo, RunnerOption, SchedulerConfig, + SchedulerPolicy, SpeculativeConfig, TaskOption, + TokenizerMode, VllmConfig, get_attr_docs, get_field) from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform from vllm.plugins import load_general_plugins from vllm.ray.lazy_utils import is_ray_initialized from vllm.reasoning import ReasoningParserManager from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 -from vllm.transformers_utils.config import is_interleaved +from 
vllm.transformers_utils.config import get_model_path, is_interleaved from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor) +from vllm.v1.sample.logits_processor import LogitsProcessor # yapf: enable @@ -151,9 +152,17 @@ def is_online_quantization(quantization: Any) -> bool: return quantization in ["inc"] +NEEDS_HELP = ( + "--help" in (argv := sys.argv) # vllm SUBCOMMAND --help + or (argv0 := argv[0]).endswith("mkdocs") # mkdocs SUBCOMMAND + or argv0.endswith("mkdocs/__main__.py") # python -m mkdocs SUBCOMMAND +) + + @functools.lru_cache(maxsize=30) def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: - cls_docs = get_attr_docs(cls) + # Save time only getting attr docs if we're generating help text + cls_docs = get_attr_docs(cls) if NEEDS_HELP else {} kwargs = {} for field in fields(cls): # Get the set of possible types for the field @@ -171,7 +180,7 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: # Get the help text for the field name = field.name - help = cls_docs[name].strip() + help = cls_docs.get(name, "").strip() # Escape % for argparse help = help.replace("%", "%%") @@ -253,6 +262,9 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: def get_kwargs(cls: ConfigType) -> dict[str, Any]: """Return argparse kwargs for the given Config dataclass. + If `--help` or `mkdocs` are not present in the command line command, the + attribute documentation will not be included in the help output. + The heavy computation is cached via functools.lru_cache, and a deep copy is returned so callers can mutate the dictionary without affecting the cached version. @@ -289,7 +301,7 @@ class EngineArgs: # is intended for expert use only. The API may change without # notice. 
distributed_executor_backend: Optional[Union[ - DistributedExecutorBackend, + str, DistributedExecutorBackend, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size @@ -303,11 +315,12 @@ class EngineArgs: data_parallel_hybrid_lb: bool = False data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") enable_eplb: bool = ParallelConfig.enable_eplb - num_redundant_experts: int = ParallelConfig.num_redundant_experts - eplb_window_size: int = ParallelConfig.eplb_window_size - eplb_step_interval: int = ParallelConfig.eplb_step_interval - eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness + num_redundant_experts: int = EPLBConfig.num_redundant_experts + eplb_window_size: int = EPLBConfig.window_size + eplb_step_interval: int = EPLBConfig.step_interval + eplb_log_balancedness: bool = EPLBConfig.log_balancedness max_parallel_loading_workers: Optional[ int] = ParallelConfig.max_parallel_loading_workers block_size: Optional[BlockSize] = CacheConfig.block_size @@ -349,7 +362,8 @@ class EngineArgs: mm_processor_kwargs: Optional[Dict[str, Any]] = \ MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = False # DEPRECATED - mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb + mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb + mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling # LoRA fields enable_lora: bool = False @@ -432,12 +446,14 @@ class EngineArgs: use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load pt_load_map_location: str = LoadConfig.pt_load_map_location - enable_multimodal_encoder_data_parallel: bool = \ - 
ParallelConfig.enable_multimodal_encoder_data_parallel + # DEPRECATED + enable_multimodal_encoder_data_parallel: bool = False + + logits_processors: Optional[list[Union[ + str, type[LogitsProcessor]]]] = ModelConfig.logits_processors + """Custom logitproc types""" async_scheduling: bool = SchedulerConfig.async_scheduling - # DEPRECATED - enable_prompt_adapter: bool = False kv_sharing_fast_prefill: bool = \ CacheConfig.kv_sharing_fast_prefill @@ -449,9 +465,18 @@ class EngineArgs: if isinstance(self.compilation_config, dict): self.compilation_config = CompilationConfig( **self.compilation_config) + if isinstance(self.eplb_config, dict): + self.eplb_config = EPLBConfig(**self.eplb_config) # Setup plugins from vllm.plugins import load_general_plugins load_general_plugins() + # when use hf offline,replace model id to local model path + if huggingface_hub.constants.HF_HUB_OFFLINE: + model_id = self.model + self.model = get_model_path(self.model, self.revision) + logger.info( + "HF_HUB_OFFLINE is True, replace model_id [%s] " \ + "to model_path [%s]",model_id, self.model) @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -500,6 +525,7 @@ class EngineArgs: model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"]) model_group.add_argument("--logprobs-mode", + choices=[f.value for f in LogprobsMode], **model_kwargs["logprobs_mode"]) model_group.add_argument("--disable-sliding-window", **model_kwargs["disable_sliding_window"]) @@ -549,6 +575,8 @@ class EngineArgs: **model_kwargs["model_impl"]) model_group.add_argument("--override-attention-dtype", **model_kwargs["override_attention_dtype"]) + model_group.add_argument("--logits-processors", + **model_kwargs["logits_processors"]) # Model loading arguments load_kwargs = get_kwargs(LoadConfig) @@ -587,7 +615,7 @@ class EngineArgs: **guided_decoding_kwargs["disable_additional_properties"]) guided_decoding_group.add_argument( "--reasoning-parser", - # This choices is a 
special case because it's not static + # This choice is a special case because it's not static choices=list(ReasoningParserManager.reasoning_parsers), **guided_decoding_kwargs["reasoning_backend"]) @@ -647,14 +675,32 @@ class EngineArgs: **parallel_kwargs["enable_expert_parallel"]) parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) - parallel_group.add_argument("--num-redundant-experts", - **parallel_kwargs["num_redundant_experts"]) - parallel_group.add_argument("--eplb-window-size", - **parallel_kwargs["eplb_window_size"]) - parallel_group.add_argument("--eplb-step-interval", - **parallel_kwargs["eplb_step_interval"]) - parallel_group.add_argument("--eplb-log-balancedness", - **parallel_kwargs["eplb_log_balancedness"]) + parallel_group.add_argument("--eplb-config", + **parallel_kwargs["eplb_config"]) + parallel_group.add_argument( + "--num-redundant-experts", + type=int, + help= + "[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( + "--eplb-window-size", + type=int, + help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( + "--eplb-step-interval", + type=int, + help= + "[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( + "--eplb-log-balancedness", + action=argparse.BooleanOptionalAction, + help= + "[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( "--max-parallel-loading-workers", **parallel_kwargs["max_parallel_loading_workers"]) @@ -670,7 +716,8 @@ class EngineArgs: **parallel_kwargs["worker_extension_cls"]) parallel_group.add_argument( "--enable-multimodal-encoder-data-parallel", - **parallel_kwargs["enable_multimodal_encoder_data_parallel"]) + action="store_true", + deprecated=True) # KV cache arguments cache_kwargs = get_kwargs(CacheConfig) @@ -720,6 +767,8 @@ class 
EngineArgs: multimodal_group.add_argument("--disable-mm-preprocessor-cache", action="store_true", deprecated=True) + multimodal_group.add_argument( + "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]) multimodal_group.add_argument( "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]) @@ -849,12 +898,6 @@ class EngineArgs: parser.add_argument('--disable-log-stats', action='store_true', help='Disable logging statistics.') - parser.add_argument('--enable-prompt-adapter', - action='store_true', - deprecated=True, - help='[DEPRECATED] Prompt adapter has been ' - 'removed. Setting this flag to True or False' - ' has no effect on vLLM behavior.') return parser @@ -894,6 +937,14 @@ class EngineArgs: self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB + if self.enable_multimodal_encoder_data_parallel: + logger.warning( + "--enable-multimodal-encoder-data-parallel` is deprecated " + "and will be removed in v0.13. " + "Please use `--mm-encoder-tp-mode data` instead.") + + self.mm_encoder_tp_mode = "data" + return ModelConfig( model=self.model, hf_config_path=self.hf_config_path, @@ -932,6 +983,7 @@ class EngineArgs: config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, + mm_encoder_tp_mode=self.mm_encoder_tp_mode, override_neuron_config=self.override_neuron_config, override_pooler_config=self.override_pooler_config, logits_processor_pattern=self.logits_processor_pattern, @@ -940,6 +992,7 @@ class EngineArgs: enable_sleep_mode=self.enable_sleep_mode, model_impl=self.model_impl, override_attention_dtype=self.override_attention_dtype, + logits_processors=self.logits_processors, ) def validate_tensorizer_args(self): @@ -1004,7 +1057,7 @@ class EngineArgs: # details from the config directly # no user input required / expected if isinstance(hf_config, SpeculatorsConfig): - # We create one since we dont create one + # We create one since we don't create one 
self.speculative_config = {} self.speculative_config[ "num_speculative_tokens"] = hf_config.num_lookahead_tokens @@ -1068,12 +1121,13 @@ class EngineArgs: # Set default arguments for V0 or V1 Engine. if use_v1: self._set_default_args_v1(usage_context, model_config) - # Disable chunked prefill for POWER (ppc64le)/ARM CPUs in V1 + # Disable chunked prefill for POWER (ppc64le)/ARM/s390x CPUs in V1 if current_platform.is_cpu( ) and current_platform.get_cpu_architecture() in ( - CpuArchEnum.POWERPC, CpuArchEnum.ARM): + CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.ARM): logger.info( - "Chunked prefill is not supported for ARM and POWER CPUs; " + "Chunked prefill is not supported for ARM and POWER " + "and S390X CPUs; " "disabling it for V1 backend.") self.enable_chunked_prefill = False else: @@ -1216,6 +1270,16 @@ class EngineArgs: "Currently, speculative decoding is not supported with " "async scheduling.") + # Forward the deprecated CLI args to the EPLB config. + if self.num_redundant_experts is not None: + self.eplb_config.num_redundant_experts = self.num_redundant_experts + if self.eplb_window_size is not None: + self.eplb_config.window_size = self.eplb_window_size + if self.eplb_step_interval is not None: + self.eplb_config.step_interval = self.eplb_step_interval + if self.eplb_log_balancedness is not None: + self.eplb_config.log_balancedness = self.eplb_log_balancedness + parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, @@ -1229,10 +1293,7 @@ class EngineArgs: data_parallel_hybrid_lb=self.data_parallel_hybrid_lb, enable_expert_parallel=self.enable_expert_parallel, enable_eplb=self.enable_eplb, - num_redundant_experts=self.num_redundant_experts, - eplb_window_size=self.eplb_window_size, - eplb_step_interval=self.eplb_step_interval, - eplb_log_balancedness=self.eplb_log_balancedness, + eplb_config=self.eplb_config, 
max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, ray_workers_use_nsight=self.ray_workers_use_nsight, @@ -1241,22 +1302,8 @@ class EngineArgs: distributed_executor_backend=self.distributed_executor_backend, worker_cls=self.worker_cls, worker_extension_cls=self.worker_extension_cls, - enable_multimodal_encoder_data_parallel=self. - enable_multimodal_encoder_data_parallel, ) - if model_config.is_multimodal_model: - dp_supports_mm_processor_cache = (self.data_parallel_size == 1 - or data_parallel_external_lb) - if (not dp_supports_mm_processor_cache - and model_config.mm_processor_cache_gb > 0): - logger.warning( - "Multi-modal processor cache is disabled because " - "it is not compatible with data parallelism when " - "there does not exist a one-to-one correspondance " - "between API and engine core processes.") - model_config.set_mm_processor_cache_gb(0) - speculative_config = self.create_speculative_config( target_model_config=model_config, target_parallel_config=parallel_config, @@ -1385,21 +1432,20 @@ class EngineArgs: recommend_to_remove=True) return False - # Need at least Ampere for now (FA support required). - # Skip this check if we are running on a non-GPU platform, - # or if the device capability is not available - # (e.g. in a Ray actor without GPUs). + # Triton v3.3 has f16 conversion regression issue on Turing and Volta, + # which broke fp16 inference + # see: https://github.com/triton-lang/triton/issues/6698 if (current_platform.is_cuda() - and current_platform.get_device_capability() - and current_platform.get_device_capability().major < 8): - _raise_or_fallback(feature_name="Compute Capability < 8.0", - recommend_to_remove=False) + and not current_platform.has_device_capability(80) + and model_config.dtype == torch.float16): + _raise_or_fallback( + feature_name="Compute Capability < 8.0 with FP16", + recommend_to_remove=False) return False - # No Fp8 KV cache so far. 
if self.kv_cache_dtype != "auto": supported = current_platform.is_kv_cache_dtype_supported( - self.kv_cache_dtype) + self.kv_cache_dtype, model_config) if not supported: _raise_or_fallback(feature_name="--kv-cache-dtype", recommend_to_remove=False) @@ -1477,11 +1523,6 @@ class EngineArgs: ############################################################# # Experimental Features - allow users to opt in. - # Signal Handlers requires running in main thread. - if (threading.current_thread() != threading.main_thread() - and _warn_or_fallback("Engine in background thread")): - return False - if self.pipeline_parallel_size > 1: supports_pp = getattr(self.distributed_executor_backend, 'supports_pp', False) @@ -1602,9 +1643,6 @@ class EngineArgs: self.enable_prefix_caching = incremental_prefill_supported logger.info("(%s) prefix caching by default", action) - if not self.enable_chunked_prefill: - self.max_num_batched_tokens = model_config.max_model_len - # V1 should use the new scheduler by default. # Swap it only if this arg is set to the original V0 default if self.scheduler_cls == EngineArgs.scheduler_cls: @@ -1692,8 +1730,11 @@ class EngineArgs: self.max_num_batched_tokens = \ default_max_num_batched_tokens[usage_context] else: - self.max_num_batched_tokens = default_max_num_batched_tokens[ - usage_context] + if not self.enable_chunked_prefill: + self.max_num_batched_tokens = model_config.max_model_len + else: + self.max_num_batched_tokens = \ + default_max_num_batched_tokens[usage_context] logger.debug( "Setting max_num_batched_tokens to %d for %s usage context.", self.max_num_batched_tokens, use_context_value) @@ -1732,7 +1773,7 @@ class AsyncEngineArgs(EngineArgs): def add_cli_args(parser: FlexibleArgumentParser, async_args_only: bool = False) -> FlexibleArgumentParser: # Initialize plugin to update the parser, for example, The plugin may - # adding a new kind of quantization method to --quantization argument or + # add a new kind of quantization method to --quantization 
argument or # a new device to --device argument. load_general_plugins() if not async_args_only: diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 73726eeab5fc7..4fb028627a8c4 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -486,10 +486,10 @@ class AsyncLLMEngine(EngineClient): _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine def __init__(self, - *args, + *args: Any, log_requests: bool = True, start_engine_loop: bool = True, - **kwargs) -> None: + **kwargs: Any) -> None: if envs.VLLM_USE_V1: raise ValueError( "Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. " @@ -998,7 +998,7 @@ class AsyncLLMEngine(EngineClient): await self.abort(request_id) raise - async def abort(self, request_id: str) -> None: + async def abort(self, request_id: Union[str, Iterable[str]]) -> None: """Abort a request. Abort a submitted request. If the request is finished or not found, @@ -1007,6 +1007,9 @@ class AsyncLLMEngine(EngineClient): Args: request_id: The unique id of the request. """ + if not isinstance(request_id, str): + raise RuntimeError("Only single-request abort supported in" + " deprecated V0") if not self.is_running: raise AsyncEngineDeadError( "Background loop is not running. 
If it was running, " diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bbe958351e87c..03c2f0375da42 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -36,6 +36,7 @@ from vllm.logits_process import get_bad_words_logits_processors from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.outputs import (PoolingRequestOutput, RequestOutput, RequestOutputFactory) @@ -250,9 +251,13 @@ class LLMEngine: self.generation_config_fields = ( self.model_config.try_get_generation_config()) - self.input_preprocessor = InputPreprocessor(self.model_config, - self.tokenizer, - mm_registry) + self.input_preprocessor = InputPreprocessor( + self.model_config, + self.tokenizer, + mm_registry, + mm_processor_cache=processor_only_cache_from_config( + self.model_config, mm_registry), + ) self.model_executor = executor_class(vllm_config=vllm_config) @@ -644,10 +649,10 @@ class LLMEngine: Details: - Set arrival_time to the current time if it is None. - Set prompt_token_ids to the encoded prompt if it is None. - - Create `n` number of [Sequence][vllm.Sequence] objects. - - Create a [SequenceGroup][vllm.SequenceGroup] object - from the list of [Sequence][vllm.Sequence]. - - Add the [SequenceGroup][vllm.SequenceGroup] object to the + - Create `n` number of [Sequence][vllm.sequence.Sequence] objects. + - Create a [SequenceGroup][vllm.sequence.SequenceGroup] object + from the list of [Sequence][vllm.sequence.Sequence]. + - Add the [SequenceGroup][vllm.sequence.SequenceGroup] object to the scheduler. 
Example: @@ -840,8 +845,8 @@ class LLMEngine: def reset_mm_cache(self) -> bool: """Reset the multi-modal cache.""" - return self.input_preprocessor.mm_registry.reset_processor_cache( - self.model_config) + self.input_preprocessor.clear_cache() + return True def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: """Reset prefix cache for all devices.""" @@ -1822,7 +1827,7 @@ class LLMEngine: assert isinstance(mm_processor, EncDecMultiModalProcessor) if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper + return # Skip encoder length check for Whisper and Donut if model_config.is_multimodal_model: suggestion = ( diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index f69f72edf6a52..0bb11328b1db5 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -5,8 +5,8 @@ import asyncio import copy import pickle from contextlib import contextmanager, suppress -from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, - Optional, Union, cast) +from typing import (Any, AsyncGenerator, Dict, Iterable, Iterator, List, + Mapping, Optional, Union, cast) import cloudpickle import psutil @@ -404,9 +404,13 @@ class MQLLMEngineClient(EngineClient): error_message="Unable to start RPC Server", socket=socket) - async def abort(self, request_id: str): + async def abort(self, request_id: Union[str, Iterable[str]]): """Send an ABORT_REQUEST signal to the RPC Server""" + if not isinstance(request_id, str): + raise RuntimeError("Only single-request abort supported in" + " deprecated V0") + with suppress(MQClientClosedError): await self._send_one_way_rpc_request( request=RPCAbortRequest(request_id), socket=self.input_socket) @@ -535,7 +539,7 @@ class MQLLMEngineClient(EngineClient): if request_id in self.output_queues: raise ValueError(f"Request {request_id} already exists") - # 1) Create output queue for this requests. 
+ # 1) Create output queue for this request. queue: asyncio.Queue[Union[RequestOutput, BaseException]] = asyncio.Queue() self.output_queues[request_id] = queue @@ -647,7 +651,7 @@ class MQLLMEngineClient(EngineClient): # Uses the same I/O as generate requests request = RPCLoadAdapterRequest(lora_request) - # Create output queue for this requests. + # Create output queue for this request. queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue() self.output_queues[request.request_id] = queue diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 671e9648a3d0c..5e8ac9c0b3987 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -3,7 +3,7 @@ import asyncio from abc import ABC, abstractmethod -from typing import AsyncGenerator, Mapping, Optional +from typing import AsyncGenerator, Iterable, Mapping, Optional, Union from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import DecodingConfig, ModelConfig, VllmConfig @@ -229,11 +229,12 @@ class EngineClient(ABC): ... @abstractmethod - async def abort(self, request_id: str) -> None: + async def abort(self, request_id: Union[str, Iterable[str]]) -> None: """Abort a request. Args: - request_id: The unique id of the request. + request_id: The unique id of the request, + or an iterable of such ids. """ ... 
@@ -328,3 +329,11 @@ class EngineClient(ABC): drain_timeout: int = 300) -> None: """Scale the engine""" raise NotImplementedError + + async def collective_rpc(self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None): + """Perform a collective RPC call to the given path.""" + raise NotImplementedError diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 74c8093f49674..7b11a50642de9 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1330,7 +1330,7 @@ def apply_mistral_chat_template( # mistral-common uses assert statements to stop processing of input # if input does not comply with the expected format. # We convert those assertion errors to ValueErrors so they can be - # are properly caught in the preprocessing_input step + # properly caught in the preprocessing_input step except (AssertionError, MistralCommonException) as e: raise ValueError(str(e)) from e @@ -1345,5 +1345,18 @@ def apply_mistral_chat_template( "template") raise ValueError(str(e)) from e -def random_tool_call_id() -> str: - return f"chatcmpl-tool-{random_uuid()}" +def get_history_tool_calls_cnt(conversation: list[ConversationMessage]): + idx = 0 + for msg in conversation: + if msg['role'] == 'assistant': + tool_calls = msg.get('tool_calls') + idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa + return idx + +def make_tool_call_id(id_type:str='random', func_name=None, idx=None): + + if id_type=='kimi_k2': + return f'functions.{func_name}:{idx}' + else: + # by default return random + return f"chatcmpl-tool-{random_uuid()}" diff --git a/vllm/entrypoints/constants.py b/vllm/entrypoints/constants.py new file mode 100644 index 0000000000000..b5bcccc35d6c8 --- /dev/null +++ b/vllm/entrypoints/constants.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Shared constants for vLLM entrypoints. 
+""" + +# HTTP header limits for h11 parser +# These constants help mitigate header abuse attacks +H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT = 4194304 # 4 MB +H11_MAX_HEADER_COUNT_DEFAULT = 256 diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index e817f07ef5947..9d587e8669339 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -3,13 +3,16 @@ import json import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Union +from collections.abc import Sequence +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, Optional, Union from openai_harmony import Author, Message, Role, StreamState, TextContent from vllm.entrypoints.harmony_utils import ( get_encoding, get_streamable_parser_for_assistant, render_for_completion) from vllm.entrypoints.tool import Tool +from vllm.entrypoints.tool_server import ToolServer from vllm.outputs import RequestOutput if TYPE_CHECKING: @@ -36,6 +39,11 @@ class ConversationContext(ABC): def render_for_completion(self) -> list[int]: pass + @abstractmethod + async def init_tool_sessions(self, tool_server: Optional[ToolServer], + exit_stack: AsyncExitStack) -> None: + pass + class SimpleContext(ConversationContext): @@ -54,28 +62,45 @@ class SimpleContext(ConversationContext): def render_for_completion(self) -> list[int]: raise NotImplementedError("Should not be called.") + async def init_tool_sessions(self, tool_server: Optional[ToolServer], + exit_stack: AsyncExitStack) -> None: + pass + class HarmonyContext(ConversationContext): def __init__( self, messages: list, - tool_sessions: dict[str, Tool], + available_tools: list[str], ): self._messages = messages - self.tool_sessions = tool_sessions + self.available_tools = available_tools + self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {} self.parser = get_streamable_parser_for_assistant() self.num_init_messages = len(messages) - # TODO(woosuk): Implement the following fields. 
self.num_prompt_tokens = 0 - self.num_cached_tokens = 0 self.num_output_tokens = 0 + # TODO(woosuk): Implement the following fields. + self.num_cached_tokens = 0 self.num_reasoning_tokens = 0 + def _update_num_prompt_tokens(self, output: RequestOutput): + if output.prompt_token_ids and len(output.prompt_token_ids) > 0: + # NOTE: with built-in tools, there might be multiple rounds in + # the conversation, with the full conversation being resent + # as new prompt each time. Hence the sum. + self.num_prompt_tokens += len(output.prompt_token_ids) + + def _update_num_output_tokens(self, token_ids: Sequence[int]): + self.num_output_tokens += len(token_ids) + def append_output(self, output) -> None: if isinstance(output, RequestOutput): + self._update_num_prompt_tokens(output) output_token_ids = output.outputs[0].token_ids + self._update_num_output_tokens(output_token_ids) self.parser = get_streamable_parser_for_assistant() for token_id in output_token_ids: self.parser.process(token_id) @@ -103,10 +128,10 @@ class HarmonyContext(ConversationContext): if recipient is not None: if recipient.startswith("browser."): return await self.call_search_tool( - self.tool_sessions["browser"], last_msg) + self._tool_sessions["browser"], last_msg) elif recipient.startswith("python"): return await self.call_python_tool( - self.tool_sessions["python"], last_msg) + self._tool_sessions["python"], last_msg) raise ValueError("No tool call found") def render_for_completion(self) -> list[int]: @@ -148,6 +173,15 @@ class HarmonyContext(ConversationContext): recipient=Role.ASSISTANT) ] + async def init_tool_sessions(self, tool_server: Optional[ToolServer], + exit_stack: AsyncExitStack) -> None: + if tool_server: + for tool_name in self.available_tools: + if tool_name not in self._tool_sessions: + self._tool_sessions[ + tool_name] = await exit_stack.enter_async_context( + tool_server.new_session(tool_name)) + class StreamingHarmonyContext(HarmonyContext): @@ -158,6 +192,7 @@ class 
StreamingHarmonyContext(HarmonyContext): self.parser = get_streamable_parser_for_assistant() self.encoding = get_encoding() self.last_tok = None + self.first_tok_of_message = True @property def messages(self) -> list: @@ -165,8 +200,18 @@ class StreamingHarmonyContext(HarmonyContext): def append_output(self, output) -> None: if isinstance(output, RequestOutput): + # append_output is called for each output token in streaming case, + # so we only want to add the prompt tokens once for each message. + if self.first_tok_of_message: + self._update_num_prompt_tokens(output) + # Reset self.first_tok_of_message if needed: + # if the current token is the last one of the current message + # (finished=True), then the next token processed will mark the + # beginning of a new message + self.first_tok_of_message = output.finished tok = output.outputs[0].token_ids[0] self.parser.process(tok) + self._update_num_output_tokens(output.outputs[0].token_ids) self.last_tok = tok else: # Handle the case of tool output in direct message format diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py index efca1472e44cf..bc810f683f4a4 100644 --- a/vllm/entrypoints/harmony_utils.py +++ b/vllm/entrypoints/harmony_utils.py @@ -329,23 +329,19 @@ def parse_chat_output( token_ids: Sequence[int]) -> tuple[Optional[str], Optional[str], bool]: parser = parse_output_into_messages(token_ids) output_msgs = parser.messages + is_tool_call = False # TODO: update this when tool call is supported if len(output_msgs) == 0: # The generation has stopped during reasoning. - is_tool_call = False reasoning_content = parser.current_content final_content = None elif len(output_msgs) == 1: # The generation has stopped during final message. 
- is_tool_call = False reasoning_content = output_msgs[0].content[0].text final_content = parser.current_content else: - if len(output_msgs) != 2: - raise ValueError( - "Expected 2 output messages (reasoning and final), " - f"but got {len(output_msgs)}.") - reasoning_msg, final_msg = output_msgs - reasoning_content = reasoning_msg.content[0].text + reasoning_msg = output_msgs[:-1] + final_msg = output_msgs[-1] + reasoning_content = "\n".join( + [msg.content[0].text for msg in reasoning_msg]) final_content = final_msg.content[0].text - is_tool_call = final_msg.recipient is not None return reasoning_content, final_content, is_tool_call diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 9f4dc19fb4ab7..4e852ba594930 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -14,6 +14,8 @@ from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError from vllm.engine.multiprocessing import MQEngineDeadError from vllm.engine.protocol import EngineClient +from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT, + H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT) from vllm.entrypoints.ssl import SSLCertRefresher from vllm.logger import init_logger from vllm.utils import find_process_using_port @@ -26,6 +28,11 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket], enable_ssl_refresh: bool = False, **uvicorn_kwargs: Any): + """ + Start a FastAPI app using Uvicorn, with support for custom Uvicorn config + options. Supports http header limits via h11_max_incomplete_event_size and + h11_max_header_count. 
+ """ logger.info("Available routes are:") for route in app.routes: methods = getattr(route, "methods", None) @@ -36,7 +43,21 @@ async def serve_http(app: FastAPI, logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) + # Extract header limit options if present + h11_max_incomplete_event_size = uvicorn_kwargs.pop( + "h11_max_incomplete_event_size", None) + h11_max_header_count = uvicorn_kwargs.pop("h11_max_header_count", None) + + # Set safe defaults if not provided + if h11_max_incomplete_event_size is None: + h11_max_incomplete_event_size = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT + if h11_max_header_count is None: + h11_max_header_count = H11_MAX_HEADER_COUNT_DEFAULT + config = uvicorn.Config(app, **uvicorn_kwargs) + # Set header limits + config.h11_max_incomplete_event_size = h11_max_incomplete_event_size + config.h11_max_header_count = h11_max_header_count config.load() server = uvicorn.Server(config) _add_shutdown_handlers(app, server) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 915f14a29b907..72b6123670b70 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -3,15 +3,13 @@ import itertools from collections.abc import Sequence -from contextlib import contextmanager -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union, - cast, overload) +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast import cloudpickle import torch.nn as nn from pydantic import ValidationError from tqdm.auto import tqdm -from typing_extensions import TypeVar, deprecated +from typing_extensions import TypeVar import vllm.envs as envs from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, @@ -40,7 +38,6 @@ from vllm.entrypoints.score_utils import (ScoreContentPartParam, from vllm.entrypoints.utils import (_validate_truncation_size, log_non_default_args) from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt -from vllm.inputs.parse import parse_and_batch_prompt from 
vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.quantization import QuantizationMethods @@ -54,7 +51,8 @@ from vllm.tasks import PoolingTask from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) from vllm.usage.usage_lib import UsageContext -from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of +from vllm.utils import Counter, Device, is_list_of +from vllm.v1.sample.logits_processor import LogitsProcessor if TYPE_CHECKING: from vllm.v1.metrics.reader import Metric @@ -156,18 +154,6 @@ class LLM: serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead. """ - DEPRECATE_LEGACY: ClassVar[bool] = True - """A flag to toggle whether to deprecate the legacy generate/encode API.""" - - @classmethod - @contextmanager - def deprecate_legacy_api(cls): - cls.DEPRECATE_LEGACY = True - - yield - - cls.DEPRECATE_LEGACY = False - def __init__( self, model: str, @@ -198,7 +184,9 @@ class LLM: override_pooler_config: Optional[PoolerConfig] = None, compilation_config: Optional[Union[int, dict[str, Any], CompilationConfig]] = None, - **kwargs, + logits_processors: Optional[list[Union[str, + type[LogitsProcessor]]]] = None, + **kwargs: Any, ) -> None: """LLM constructor.""" @@ -272,6 +260,7 @@ class LLM: mm_processor_kwargs=mm_processor_kwargs, override_pooler_config=override_pooler_config, compilation_config=compilation_config_instance, + logits_processors=logits_processors, **kwargs, ) @@ -321,99 +310,14 @@ class LLM: return SamplingParams.from_optional(**self.default_sampling_params) return SamplingParams() - @overload def generate( self, prompts: Union[PromptType, Sequence[PromptType]], - /, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... 
- - @overload # LEGACY: single (prompt + optional token ids) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: str, - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, - prompt_token_ids: Optional[list[int]] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @overload # LEGACY: multi (prompt + optional token ids) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: list[str], - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, - prompt_token_ids: Optional[list[list[int]]] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @overload # LEGACY: single (token ids + optional prompt) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: Optional[str] = None, - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, - *, - prompt_token_ids: list[int], - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @overload # LEGACY: multi (token ids + optional prompt) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: Optional[list[str]] = None, - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, - *, - prompt_token_ids: list[list[int]], - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... 
- - @overload # LEGACY: single or multi token ids [pos-only] - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: None, - sampling_params: None, - prompt_token_ids: Union[list[int], list[list[int]]], - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @deprecate_kwargs( - "prompt_token_ids", - is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'prompts' parameter instead.", - ) - def generate( - self, - prompts: Union[Union[PromptType, Sequence[PromptType]], - Optional[Union[str, list[str]]]] = None, - sampling_params: Optional[Union[SamplingParams, - Sequence[SamplingParams]]] = None, - prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, priority: Optional[list[int]] = None, ) -> list[RequestOutput]: """Generates the completions for the input prompts. @@ -456,15 +360,6 @@ class LLM: "Try passing `--runner generate` to use the model as a " "generative model.") - if prompt_token_ids is not None: - parsed_prompts = self._convert_v1_inputs( - prompts=cast(Optional[Union[str, list[str]]], prompts), - prompt_token_ids=prompt_token_ids, - ) - else: - parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], - prompts) - if sampling_params is None: # Use default sampling params. 
sampling_params = self.get_default_sampling_params() @@ -479,10 +374,10 @@ class LLM: # Add any modality specific loras to the corresponding prompts lora_request = self._get_modality_specific_lora_reqs( - parsed_prompts, lora_request) + prompts, lora_request) self._validate_and_add_requests( - prompts=parsed_prompts, + prompts=prompts, params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, @@ -494,7 +389,7 @@ class LLM: return self.engine_class.validate_outputs(outputs, RequestOutput) def _get_modality_specific_lora_reqs( - self, parsed_prompts: Union[PromptType, Sequence[PromptType]], + self, prompts: Union[PromptType, Sequence[PromptType]], lora_request: Optional[Union[list[LoRARequest], LoRARequest]]): # Grab the lora config off the vllm config on the engine, # since this is the same for both v0 & v1. @@ -507,35 +402,33 @@ class LLM: or (lora_config and lora_config.default_mm_loras is None)): return lora_request - if not isinstance(parsed_prompts, Sequence): - parsed_prompts = [parsed_prompts] + if not isinstance(prompts, Sequence): + prompts = [prompts] - optional_loras = ([lora_request] * len(parsed_prompts) + optional_loras = ([lora_request] * len(prompts) if not isinstance(lora_request, Sequence) else lora_request) return [ self._resolve_single_prompt_mm_lora( - parsed_prompt, + prompt, opt_lora_req, lora_config.default_mm_loras, - ) for parsed_prompt, opt_lora_req in zip(parsed_prompts, - optional_loras) + ) for prompt, opt_lora_req in zip(prompts, optional_loras) ] - def _resolve_single_prompt_mm_lora(self, parsed_prompt: PromptType, + def _resolve_single_prompt_mm_lora(self, prompt: PromptType, lora_request: Optional[LoRARequest], default_mm_loras: Optional[dict[str, str]]): - if (not default_mm_loras or not isinstance(parsed_prompt, dict) - or "multi_modal_data" not in parsed_prompt): + if (not default_mm_loras or not isinstance(prompt, dict) + or "multi_modal_data" not in prompt): return lora_request - parsed_prompt = 
cast(Union[TextPrompt, TokensPrompt], parsed_prompt) + prompt = cast(Union[TextPrompt, TokensPrompt], prompt) - intersection = set( - parsed_prompt["multi_modal_data"].keys()).intersection( - default_mm_loras.keys()) + intersection = set(prompt["multi_modal_data"].keys()) \ + .intersection(default_mm_loras.keys()) if not intersection: return lora_request if len(intersection) > 1: @@ -630,6 +523,7 @@ class LLM: params: BeamSearchParams, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, use_tqdm: bool = False, + concurrency_limit: Optional[int] = None, ) -> list[BeamSearchOutput]: """ Generate sequences using beam search. @@ -640,6 +534,8 @@ class LLM: params: The beam search parameters. lora_request: LoRA request to use for generation, if any. use_tqdm: Whether to use tqdm to display the progress bar. + concurrency_limit: The maximum number of concurrent requests. + If None, the number of concurrent requests is unlimited. """ # TODO: how does beam search work together with length penalty, # frequency, penalty, and stopping criteria, etc.? @@ -658,6 +554,15 @@ class LLM: length_penalty, ) + if use_tqdm and concurrency_limit is not None: + logger.warning( + "Progress bar is not supported when using concurrency_limit. " + "Disabling progress bar.") + use_tqdm = False + + if concurrency_limit is None: + concurrency_limit = len(prompts) + def create_tokens_prompt_from_beam( beam: BeamSearchSequence) -> TokensPrompt: token_prompt_kwargs: TokensPrompt = { @@ -702,73 +607,79 @@ class LLM: **mm_kwargs, ), ) - token_iter = range(max_tokens) - if use_tqdm: - token_iter = tqdm(token_iter, - desc="Beam search", - unit="token", - unit_scale=False) - logger.warning( - "The progress bar shows the upper bound on token steps and " - "may finish early due to stopping conditions. 
It does not " - "reflect instance-level progress.") + for prompt_start in range(0, len(prompts), concurrency_limit): + instances_batch = instances[prompt_start:prompt_start + + concurrency_limit] - for _ in token_iter: - all_beams: list[BeamSearchSequence] = list( - sum((instance.beams for instance in instances), [])) - pos = [0] + list( - itertools.accumulate( - len(instance.beams) for instance in instances)) - instance_start_and_end: list[tuple[int, int]] = list( - zip(pos[:-1], pos[1:])) + token_iter = range(max_tokens) + if use_tqdm: + token_iter = tqdm(token_iter, + desc="Beam search", + unit="token", + unit_scale=False) + logger.warning( + "The progress bar shows the upper bound on token steps and " + "may finish early due to stopping conditions. It does not " + "reflect instance-level progress.") + for _ in token_iter: + all_beams: list[BeamSearchSequence] = list( + sum((instance.beams for instance in instances_batch), [])) + pos = [0] + list( + itertools.accumulate( + len(instance.beams) for instance in instances_batch)) + instance_start_and_end: list[tuple[int, int]] = list( + zip(pos[:-1], pos[1:])) - if len(all_beams) == 0: - break + if len(all_beams) == 0: + break - # create the corresponding batch entries for prompt & optional lora - prompts_batch, lora_req_batch = zip( - *[(create_tokens_prompt_from_beam(beam), beam.lora_request) - for beam in all_beams]) + # create corresponding batch entries for prompt & optional lora + prompts_batch, lora_req_batch = zip( + *[(create_tokens_prompt_from_beam(beam), beam.lora_request) + for beam in all_beams]) - # only runs for one step - # we don't need to use tqdm here - output = self.generate(prompts_batch, - sampling_params=beam_search_params, - use_tqdm=False, - lora_request=lora_req_batch) + # only runs for one step + # we don't need to use tqdm here + output = self.generate(prompts_batch, + sampling_params=beam_search_params, + use_tqdm=False, + lora_request=lora_req_batch) - for (start, end), instance in 
zip(instance_start_and_end, - instances): - instance_new_beams = [] - for i in range(start, end): - current_beam = all_beams[i] - result = output[i] + for (start, end), instance in zip(instance_start_and_end, + instances_batch): + instance_new_beams = [] + for i in range(start, end): + current_beam = all_beams[i] + result = output[i] - if result.outputs[0].logprobs is not None: - # if `result.outputs[0].logprobs` is None, it means - # the sequence is completed because of the max-model-len - # or abortion. we don't need to add it to the new beams. - logprobs = result.outputs[0].logprobs[0] - for token_id, logprob_obj in logprobs.items(): - new_beam = BeamSearchSequence( - tokens=current_beam.tokens + [token_id], - logprobs=current_beam.logprobs + [logprobs], - lora_request=current_beam.lora_request, - cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob, - multi_modal_data=current_beam.multi_modal_data, - mm_processor_kwargs=current_beam. - mm_processor_kwargs) + if result.outputs[0].logprobs is not None: + # if `result.outputs[0].logprobs` is None, it means + # the sequence is completed because of the + # max-model-len or abortion. we don't need to add + # it to the new beams. + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + + [logprobs], + lora_request=current_beam.lora_request, + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + multi_modal_data=current_beam. + multi_modal_data, + mm_processor_kwargs=current_beam. 
+ mm_processor_kwargs) - if token_id == tokenizer.eos_token_id and \ - not ignore_eos: - instance.completed.append(new_beam) - else: - instance_new_beams.append(new_beam) - sorted_beams = sorted(instance_new_beams, - key=sort_beams_key, - reverse=True) - instance.beams = sorted_beams[:beam_width] + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + instance.completed.append(new_beam) + else: + instance_new_beams.append(new_beam) + sorted_beams = sorted(instance_new_beams, + key=sort_beams_key, + reverse=True) + instance.beams = sorted_beams[:beam_width] outputs = [] for instance in instances: @@ -804,8 +715,8 @@ class LLM: Generate responses for a chat conversation. The chat conversation is converted into a text prompt using the - tokenizer and calls the [generate][] method to generate the - responses. + tokenizer and calls the [generate][vllm.LLM.generate] method to generate + the responses. Multi-modal inputs can be passed in the same way you would pass them to the OpenAI API. @@ -929,11 +840,9 @@ class LLM: lora_request=lora_request, ) - @overload def encode( self, prompts: Union[PromptType, Sequence[PromptType]], - /, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, *, @@ -942,107 +851,6 @@ class LLM: lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, pooling_task: PoolingTask = "encode", tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... 
- - @overload # LEGACY: single (prompt + optional token ids) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: str, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - prompt_token_ids: Optional[list[int]] = None, - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @overload # LEGACY: multi (prompt + optional token ids) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: list[str], - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - prompt_token_ids: Optional[list[list[int]]] = None, - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @overload # LEGACY: single (token ids + optional prompt) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: Optional[str] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - *, - prompt_token_ids: list[int], - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... 
- - @overload # LEGACY: multi (token ids + optional prompt) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: Optional[list[str]] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - *, - prompt_token_ids: list[list[int]], - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @overload # LEGACY: single or multi token ids [pos-only] - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: None, - pooling_params: None, - prompt_token_ids: Union[list[int], list[list[int]]], - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... 
- - @deprecate_kwargs( - "prompt_token_ids", - is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'prompts' parameter instead.", - ) - def encode( - self, - prompts: Union[Union[PromptType, Sequence[PromptType]], - Optional[Union[str, list[str]]]] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: Optional[PoolingTask] = None, - tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[PoolingRequestOutput]: """Apply pooling to the hidden states corresponding to the input prompts. @@ -1104,15 +912,6 @@ class LLM: raise ValueError( f"pooling_task must be one of {self.supported_tasks}.") - if prompt_token_ids is not None: - parsed_prompts = self._convert_v1_inputs( - prompts=cast(Optional[Union[str, list[str]]], prompts), - prompt_token_ids=prompt_token_ids, - ) - else: - parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], - prompts) - if pooling_params is None: # Use default pooling params. 
pooling_params = PoolingParams() @@ -1130,7 +929,7 @@ class LLM: tokenization_kwargs) self._validate_and_add_requests( - prompts=parsed_prompts, + prompts=prompts, params=pooling_params, use_tqdm=use_tqdm, lora_request=lora_request, @@ -1144,7 +943,6 @@ class LLM: def embed( self, prompts: Union[PromptType, Sequence[PromptType]], - /, *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, @@ -1194,7 +992,6 @@ class LLM: def classify( self, prompts: Union[PromptType, Sequence[PromptType]], - /, *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, pooling_params: Optional[Union[PoolingParams, @@ -1344,7 +1141,7 @@ class LLM: _validate_truncation_size(model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs) - parsed_prompts = [] + prompts = list[PromptType]() input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] @@ -1368,10 +1165,10 @@ class LLM: else: pooling_params_list.append(pooling_params) - parsed_prompts.append(engine_prompt) + prompts.append(engine_prompt) self._validate_and_add_requests( - prompts=parsed_prompts, + prompts=prompts, params=pooling_params_list, use_tqdm=use_tqdm, lora_request=lora_request, @@ -1555,8 +1352,8 @@ class LLM: def wake_up(self, tags: Optional[list[str]] = None): """ - Wake up the engine from sleep mode. See the [sleep][] method - for more details. + Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep] + method for more details. 
Args: tags: An optional list of tags to reallocate the engine memory @@ -1581,48 +1378,6 @@ class LLM: assert isinstance(self.llm_engine, V1LLMEngine) return self.llm_engine.get_metrics() - # LEGACY - def _convert_v1_inputs( - self, - prompts: Optional[Union[str, list[str]]], - prompt_token_ids: Optional[Union[list[int], list[list[int]]]], - ): - # skip_tokenizer_init is now checked in engine - - if prompts is None and prompt_token_ids is None: - raise ValueError( - "Either prompts or prompt_token_ids must be provided.") - if prompts is not None and prompt_token_ids is not None \ - and len(prompts) != len(prompt_token_ids): - raise ValueError( - "The lengths of prompts and prompt_token_ids must be the same." - ) - - if prompts is not None: - prompts = [p["content"] for p in parse_and_batch_prompt(prompts)] - if prompt_token_ids is not None: - prompt_token_ids = [ - p["content"] for p in parse_and_batch_prompt(prompt_token_ids) - ] - if prompts is not None: - num_requests = len(prompts) - elif prompt_token_ids is not None: - num_requests = len(prompt_token_ids) - parsed_prompts: list[PromptType] = [] - for i in range(num_requests): - item: PromptType - - if prompts is not None: - item = TextPrompt(prompt=prompts[i]) - elif prompt_token_ids is not None: - item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) - else: - raise AssertionError - - parsed_prompts.append(item) - - return parsed_prompts - def _validate_and_add_requests( self, prompts: Union[PromptType, Sequence[PromptType]], diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index af86835a497d4..9a2470649c8d2 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -600,8 +600,11 @@ async def create_responses(request: ResponsesRequest, raw_request: Request): if handler is None: return base(raw_request).create_error_response( message="The model does not support Responses API") - - generator = await 
handler.create_responses(request, raw_request) + try: + generator = await handler.create_responses(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -618,7 +621,11 @@ async def retrieve_responses(response_id: str, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Responses API") - response = await handler.retrieve_responses(response_id) + try: + response = await handler.retrieve_responses(response_id) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), @@ -633,7 +640,11 @@ async def cancel_responses(response_id: str, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Responses API") - response = await handler.cancel_responses(response_id) + try: + response = await handler.cancel_responses(response_id) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), @@ -667,9 +678,11 @@ async def create_chat_completion(request: ChatCompletionRequest, if handler is None: return base(raw_request).create_error_response( message="The model does not support Chat Completions API") - - generator = await handler.create_chat_completion(request, raw_request) - + try: + generator = await handler.create_chat_completion(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), 
status_code=generator.error.code) @@ -742,7 +755,11 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Embeddings API") - generator = await handler.create_embedding(request, raw_request) + try: + generator = await handler.create_embedding(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -770,8 +787,11 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): if handler is None: return base(raw_request).create_error_response( message="The model does not support Pooling API") - - generator = await handler.create_pooling(request, raw_request) + try: + generator = await handler.create_pooling(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.error.code) @@ -791,7 +811,11 @@ async def create_classify(request: ClassificationRequest, return base(raw_request).create_error_response( message="The model does not support Classification API") - generator = await handler.create_classify(request, raw_request) + try: + generator = await handler.create_classify(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.error.code) @@ -820,7 +844,11 @@ async def create_score(request: ScoreRequest, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Score API") - generator = await handler.create_score(request, 
raw_request) + try: + generator = await handler.create_score(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.error.code) @@ -878,8 +906,12 @@ async def create_transcriptions(raw_request: Request, message="The model does not support Transcriptions API") audio_data = await request.file.read() - generator = await handler.create_transcription(audio_data, request, - raw_request) + try: + generator = await handler.create_transcription(audio_data, request, + raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -919,8 +951,12 @@ async def create_translations(request: Annotated[TranslationRequest, message="The model does not support Translations API") audio_data = await request.file.read() - generator = await handler.create_translation(audio_data, request, - raw_request) + try: + generator = await handler.create_translation(audio_data, request, + raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -949,7 +985,11 @@ async def do_rerank(request: RerankRequest, raw_request: Request): if handler is None: return base(raw_request).create_error_response( message="The model does not support Rerank (Score) API") - generator = await handler.do_rerank(request, raw_request) + try: + generator = await handler.do_rerank(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return 
JSONResponse(content=generator.model_dump(), status_code=generator.error.code) @@ -1044,6 +1084,34 @@ if envs.VLLM_SERVER_DEV_MODE: is_sleeping = await engine_client(raw_request).is_sleeping() return JSONResponse(content={"is_sleeping": is_sleeping}) + @router.post("/collective_rpc") + async def collective_rpc(raw_request: Request): + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}") from e + method = body.get("method") + if method is None: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, + detail="Missing 'method' in request body") + # For security reason, only serialized string args/kwargs are passed. + # User-defined `method` is responsible for deseralization if needed. + args: list[str] = body.get("args", []) + kwargs: dict[str, str] = body.get("kwargs", {}) + timeout: Optional[float] = body.get("timeout") + results = await engine_client(raw_request).collective_rpc( + method=method, timeout=timeout, args=tuple(args), kwargs=kwargs) + if results is None: + return Response(status_code=200) + response: list[Any] = [] + for result in results: + if result is None or isinstance(result, (dict, list)): + response.append(result) + else: + response.append(str(result)) + return JSONResponse(content={"results": response}) + @router.post("/scale_elastic_ep", dependencies=[Depends(validate_json_request)], @@ -1680,6 +1748,8 @@ async def init_app_state( reasoning_parser=args.reasoning_parser, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, + log_error_stack=args.log_error_stack, ) if "generate" in supported_tasks else None state.openai_serving_chat = OpenAIServingChat( engine_client, @@ -1697,6 +1767,8 @@ async def init_app_state( reasoning_parser=args.reasoning_parser, 
enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, + log_error_stack=args.log_error_stack, ) if "generate" in supported_tasks else None state.openai_serving_completion = OpenAIServingCompletion( engine_client, @@ -1706,6 +1778,7 @@ async def init_app_state( return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, + log_error_stack=args.log_error_stack, ) if "generate" in supported_tasks else None state.openai_serving_pooling = OpenAIServingPooling( engine_client, @@ -1714,6 +1787,7 @@ async def init_app_state( request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, + log_error_stack=args.log_error_stack, ) if "encode" in supported_tasks else None state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, @@ -1722,12 +1796,14 @@ async def init_app_state( request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, + log_error_stack=args.log_error_stack, ) if "embed" in supported_tasks else None state.openai_serving_classification = ServingClassification( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, + log_error_stack=args.log_error_stack, ) if "classify" in supported_tasks else None enable_serving_reranking = ("classify" in supported_tasks and getattr( @@ -1737,6 +1813,7 @@ async def init_app_state( model_config, state.openai_serving_models, request_logger=request_logger, + log_error_stack=args.log_error_stack, ) if ("embed" in supported_tasks or enable_serving_reranking) else None state.openai_serving_tokenization = OpenAIServingTokenization( @@ -1746,18 +1823,21 @@ async def init_app_state( request_logger=request_logger, 
chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, + log_error_stack=args.log_error_stack, ) state.openai_serving_transcription = OpenAIServingTranscription( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, + log_error_stack=args.log_error_stack, ) if "transcription" in supported_tasks else None state.openai_serving_translation = OpenAIServingTranslation( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, + log_error_stack=args.log_error_stack, ) if "transcription" in supported_tasks else None state.enable_server_load_tracking = args.enable_server_load_tracking @@ -1894,6 +1974,8 @@ async def run_server_worker(listen_address, ssl_certfile=args.ssl_certfile, ssl_ca_certs=args.ssl_ca_certs, ssl_cert_reqs=args.ssl_cert_reqs, + h11_max_incomplete_event_size=args.h11_max_incomplete_event_size, + h11_max_header_count=args.h11_max_header_count, **uvicorn_kwargs, ) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index e15f65b43082c..d0b5d013eb9e5 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -20,6 +20,8 @@ from vllm.config import config from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, validate_chat_template) +from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT, + H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT) from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.logger import init_logger @@ -172,6 +174,14 @@ schema. 
Example: `[{"type": "text", "text": "Hello world!"}]`""" enable_log_outputs: bool = False """If set to True, enable logging of model outputs (generations) in addition to the input logging that is enabled by default.""" + h11_max_incomplete_event_size: int = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT + """Maximum size (bytes) of an incomplete HTTP event (header or body) for + h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB).""" + h11_max_header_count: int = H11_MAX_HEADER_COUNT_DEFAULT + """Maximum number of HTTP headers allowed in a request for h11 parser. + Helps mitigate header abuse. Default: 256.""" + log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE + """If set to True, log the stack trace of error responses""" @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 543701ed144ee..5cb41bd93d4bc 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -20,7 +20,15 @@ from openai.types.chat.chat_completion_message import ( from openai.types.responses import (ResponseFunctionToolCall, ResponseInputItemParam, ResponseOutputItem, ResponsePrompt, ResponseReasoningItem, - ResponseStatus, ResponseTextConfig) + ResponseStatus) + +# Backward compatibility for OpenAI client versions +try: # For older openai versions (< 1.100.0) + from openai.types.responses import ResponseTextConfig +except ImportError: # For newer openai versions (>= 1.100.0) + from openai.types.responses import (ResponseFormatTextConfig as + ResponseTextConfig) + from openai.types.responses.response import ToolChoice from openai.types.responses.tool import Tool from openai.types.shared import Metadata, Reasoning @@ -30,7 +38,7 @@ from typing_extensions import TypeAlias from vllm import envs from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, - random_tool_call_id) + make_tool_call_id) from 
vllm.entrypoints.score_utils import (ScoreContentPartParam, ScoreMultiModalParam) from vllm.logger import init_logger @@ -349,13 +357,22 @@ class ResponsesRequest(OpenAIBaseModel): temperature=temperature, top_p=top_p, max_tokens=max_tokens, - logprobs=self.top_logprobs, + logprobs=self.top_logprobs + if self.is_include_output_logprobs() else None, stop_token_ids=stop_token_ids, output_kind=(RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY), guided_decoding=guided_decoding, ) + def is_include_output_logprobs(self) -> bool: + """Check if the request includes output logprobs.""" + if self.include is None: + return False + return isinstance( + self.include, + list) and "message.output_text.logprobs" in self.include + @model_validator(mode="before") def validate_background(cls, data): if not data.get("background"): @@ -568,6 +585,14 @@ class ChatCompletionRequest(OpenAIBaseModel): "If specified with 'logprobs', tokens are represented " " as strings of the form 'token_id:{token_id}' so that tokens " "that are not JSON-encodable can be identified.")) + return_token_ids: Optional[bool] = Field( + default=None, + description=( + "If specified, the result will include token IDs alongside the " + "generated text. In streaming mode, prompt_token_ids is included " + "only in the first chunk, and token_ids contains the delta tokens " + "for each chunk. This is useful for debugging or when you " + "need to map generated text back to input tokens.")) cache_salt: Optional[str] = Field( default=None, description=( @@ -1054,6 +1079,14 @@ class CompletionRequest(OpenAIBaseModel): "If specified with 'logprobs', tokens are represented " " as strings of the form 'token_id:{token_id}' so that tokens " "that are not JSON-encodable can be identified.")) + return_token_ids: Optional[bool] = Field( + default=None, + description=( + "If specified, the result will include token IDs alongside the " + "generated text. 
In streaming mode, prompt_token_ids is included " + "only in the first chunk, and token_ids contains the delta tokens " + "for each chunk. This is useful for debugging or when you " + "need to map generated text back to input tokens.")) cache_salt: Optional[str] = Field( default=None, @@ -1472,7 +1505,9 @@ class CompletionResponseChoice(OpenAIBaseModel): "to stop, None if the completion finished for some other reason " "including encountering the EOS token"), ) + token_ids: Optional[list[int]] = None # For response prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None + prompt_token_ids: Optional[list[int]] = None # For prompt class CompletionResponse(OpenAIBaseModel): @@ -1503,6 +1538,10 @@ class CompletionResponseStreamChoice(OpenAIBaseModel): "to stop, None if the completion finished for some other reason " "including encountering the EOS token"), ) + # not part of the OpenAI spec but for tracing the tokens + # prompt tokens is put into choice to align with CompletionResponseChoice + prompt_token_ids: Optional[list[int]] = None + token_ids: Optional[list[int]] = None class CompletionStreamResponse(OpenAIBaseModel): @@ -1604,7 +1643,7 @@ class FunctionCall(OpenAIBaseModel): class ToolCall(OpenAIBaseModel): - id: str = Field(default_factory=random_tool_call_id) + id: str = Field(default_factory=make_tool_call_id) type: Literal["function"] = "function" function: FunctionCall @@ -1672,6 +1711,9 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): finish_reason: Optional[str] = "stop" # not part of the OpenAI spec but included in vLLM for legacy reasons stop_reason: Optional[Union[int, str]] = None + # not part of the OpenAI spec but is useful for tracing the tokens + # in agent scenarios + token_ids: Optional[list[int]] = None class ChatCompletionResponse(OpenAIBaseModel): @@ -1687,6 +1729,7 @@ class ChatCompletionResponse(OpenAIBaseModel): # vLLM-specific fields that are not in OpenAI spec prompt_logprobs: Optional[list[Optional[dict[int, 
Logprob]]]] = None + prompt_token_ids: Optional[list[int]] = None kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, description="KVTransfer parameters.") @@ -1704,6 +1747,8 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel): logprobs: Optional[ChatCompletionLogProbs] = None finish_reason: Optional[str] = None stop_reason: Optional[Union[int, str]] = None + # not part of the OpenAI spec but for tracing the tokens + token_ids: Optional[list[int]] = None class ChatCompletionStreamResponse(OpenAIBaseModel): @@ -1713,6 +1758,8 @@ class ChatCompletionStreamResponse(OpenAIBaseModel): model: str choices: list[ChatCompletionResponseStreamChoice] usage: Optional[UsageInfo] = Field(default=None) + # not part of the OpenAI spec but for tracing the tokens + prompt_token_ids: Optional[list[int]] = None class TranscriptionResponseStreamChoice(OpenAIBaseModel): @@ -1770,7 +1817,7 @@ class ResponsesResponse(OpenAIBaseModel): service_tier: Literal["auto", "default", "flex", "scale", "priority"] status: ResponseStatus text: Optional[ResponseTextConfig] = None - top_logprobs: int + top_logprobs: Optional[int] = None truncation: Literal["auto", "disabled"] usage: Optional[ResponseUsage] = None user: Optional[str] = None @@ -2185,9 +2232,15 @@ class TranscriptionRequest(OpenAIBaseModel): # Transcription response objects +class TranscriptionUsageAudio(OpenAIBaseModel): + type: Literal["duration"] = "duration" + seconds: int + + class TranscriptionResponse(OpenAIBaseModel): text: str """The transcribed text.""" + usage: TranscriptionUsageAudio class TranscriptionWord(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index b4231c6d10c4e..1c0ffdfb91897 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -19,7 +19,8 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import 
(ChatTemplateContentFormatOption, ConversationMessage, - random_tool_call_id) + get_history_tool_calls_cnt, + make_tool_call_id) from vllm.entrypoints.harmony_utils import ( get_developer_message, get_stop_tokens_for_assistant_actions, get_streamable_parser_for_assistant, get_system_message, parse_chat_input, @@ -50,6 +51,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls, truncate_tool_call_ids, validate_request_params) +from vllm.utils import as_list logger = init_logger(__name__) @@ -74,13 +76,15 @@ class OpenAIServingChat(OpenAIServing): enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, enable_log_outputs: bool = False, + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - enable_force_include_usage=enable_force_include_usage) + enable_force_include_usage=enable_force_include_usage, + log_error_stack=log_error_stack) self.response_role = response_role self.chat_template = chat_template @@ -132,6 +136,10 @@ class OpenAIServingChat(OpenAIServing): source = "model" if source == "auto" else source logger.info("Using default chat sampling params from %s: %s", source, self.default_sampling_params) + if self.model_config.hf_config.model_type == 'kimi_k2': + self.tool_call_id_type = 'kimi_k2' + else: + self.tool_call_id_type = 'random' self.use_harmony = model_config.hf_config.model_type == "gpt_oss" if self.use_harmony: @@ -378,6 +386,7 @@ class OpenAIServingChat(OpenAIServing): current_text: Optional[str], delta_text: str, function_name_returned: bool, + tool_call_idx: Optional[int] = None ) -> tuple[Optional[DeltaMessage], bool]: if current_text is None or current_text == "": # if the current text is empty, we cannot parse it @@ -423,8 +432,12 @@ class 
OpenAIServingChat(OpenAIServing): current_tool_call = obj[-2] function_name_returned = True + tool_call_id = make_tool_call_id( + id_type=self.tool_call_id_type, + func_name=current_tool_call["name"], + idx=tool_call_idx) delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall(id=random_tool_call_id(), + DeltaToolCall(id=tool_call_id, function=DeltaFunctionCall( name=current_tool_call["name"], arguments=arguments), @@ -490,6 +503,10 @@ class OpenAIServingChat(OpenAIServing): all_previous_token_ids: Optional[list[list[int]]] function_name_returned = [False] * num_choices + if self.tool_call_id_type == 'kimi_k2': + history_tool_call_cnt = get_history_tool_calls_cnt(conversation) + else: + history_tool_call_cnt = 0 # Always track previous_texts for comprehensive output logging previous_texts = [""] * num_choices @@ -567,12 +584,17 @@ class OpenAIServingChat(OpenAIServing): ), logprobs=None, finish_reason=None) + + # return prompt_token_ids at the first chunk ever chunk = ChatCompletionStreamResponse( id=request_id, object=chunk_object_type, created=created_time, choices=[choice_data], - model=model_name) + model=model_name, + prompt_token_ids=(res.prompt_token_ids + if request.return_token_ids else + None)) # if continuous usage stats are requested, add it if include_continuous_usage: @@ -643,9 +665,9 @@ class OpenAIServingChat(OpenAIServing): harmony_parser = harmony_parsers[i] for token_id in output.token_ids: harmony_parser.process(token_id) - # FIXME(woosuk): Support function calling - is_final = harmony_parser.current_channel == "final" - if not (request.include_reasoning or is_final): + is_reasoning = \ + harmony_parser.current_channel == "analysis" + if not request.include_reasoning and is_reasoning: # Skip the reasoning content. 
continue delta_text = harmony_parser.last_content_delta or "" @@ -667,20 +689,19 @@ class OpenAIServingChat(OpenAIServing): previous_text = previous_texts[i] previous_token_ids = all_previous_token_ids[i] current_text = previous_text + delta_text - # avoid the None + list error. if previous_token_ids: - current_token_ids = previous_token_ids + list( + current_token_ids = previous_token_ids + as_list( output.token_ids) else: - current_token_ids = list(output.token_ids) + current_token_ids = as_list(output.token_ids) if self.use_harmony: - if is_final: - delta_message = DeltaMessage(content=delta_text) - else: + if is_reasoning: delta_message = DeltaMessage( reasoning_content=delta_text) + else: + delta_message = DeltaMessage(content=delta_text) # handle streaming deltas for tools with named tool_choice elif tool_choice_function_name: if (self.reasoning_parser and not reasoning_end_arr[i] @@ -703,11 +724,10 @@ class OpenAIServingChat(OpenAIServing): # set reasoning status to end. # Only keep 'content', remove 'reasoning_content'. 
if reasoning_parser.is_reasoning_end( - list(output.token_ids)) or \ - (res.prompt_token_ids and - reasoning_parser.is_reasoning_end( - list(res.prompt_token_ids) - )): + as_list(output.token_ids)) or ( + res.prompt_token_ids + and reasoning_parser.is_reasoning_end( + res.prompt_token_ids)): reasoning_end_arr[i] = True if delta_message and delta_message.content: # This need to be added to next `delta_text` @@ -728,7 +748,7 @@ class OpenAIServingChat(OpenAIServing): index=i) else: delta_tool_call = DeltaToolCall( - id=random_tool_call_id(), + id=make_tool_call_id(), type="function", function=DeltaFunctionCall( name=tool_choice_function_name, @@ -759,7 +779,11 @@ class OpenAIServingChat(OpenAIServing): previous_text=previous_text, current_text=content, delta_text=delta_text, - function_name_returned=fn_name_returned)) + function_name_returned=fn_name_returned, + tool_call_idx=history_tool_call_cnt)) + if (delta_message and delta_message.tool_calls and + delta_message.tool_calls[0].id is not None): + history_tool_call_cnt += 1 # update the previous values for the next iteration previous_texts[i] = current_text @@ -771,6 +795,7 @@ class OpenAIServingChat(OpenAIServing): assert reasoning_parser is not None assert added_content_delta_arr is not None assert reasoning_end_arr is not None + output_token_ids = as_list(output.token_ids) if not reasoning_end_arr[i]: delta_message = ( reasoning_parser. @@ -780,7 +805,7 @@ class OpenAIServingChat(OpenAIServing): delta_text, previous_token_ids, current_token_ids, - output.token_ids, + output_token_ids, )) # When encountering think end id in prompt_token_ids # i.e {"enable_thinking": False}, @@ -789,9 +814,9 @@ class OpenAIServingChat(OpenAIServing): # to 'reasoning_content'. 
if res.prompt_token_ids and \ reasoning_parser.is_reasoning_end( - list(res.prompt_token_ids)): + res.prompt_token_ids): reasoning_end_arr[i] = True - current_token_ids = list(output.token_ids) + current_token_ids = output_token_ids if delta_message and delta_message.content: current_text = delta_message.content delta_message.content = None @@ -802,11 +827,11 @@ class OpenAIServingChat(OpenAIServing): # Remove the text and token ids related # to 'reasoning_content'. if reasoning_parser.is_reasoning_end( - list(output.token_ids)): + output_token_ids): reasoning_end_arr[i] = True current_token_ids = \ reasoning_parser.extract_content_ids( - list(output.token_ids)) + output_token_ids) if delta_message and delta_message.content: current_text = delta_message.content delta_message.content = None @@ -815,7 +840,7 @@ class OpenAIServingChat(OpenAIServing): # handle tool calls only after reasoning is done, else: - delta_token_ids = list(output.token_ids) + delta_token_ids = output_token_ids # First time to tool call, # add the remaining text and token ids # to delta from previous @@ -864,7 +889,8 @@ class OpenAIServingChat(OpenAIServing): delta_message = DeltaMessage(content=delta_text) # update the previous values for the next iteration - if tool_choice_auto or self.reasoning_parser: + if ((tool_choice_auto or self.reasoning_parser) + and not self.use_harmony): assert previous_texts is not None assert all_previous_token_ids is not None previous_texts[i] = current_text @@ -899,7 +925,7 @@ class OpenAIServingChat(OpenAIServing): self.request_logger.log_outputs( request_id=request_id, outputs=delta_content, - output_token_ids=list(output.token_ids), + output_token_ids=as_list(output.token_ids), finish_reason=output.finish_reason, is_streaming=True, delta=True, @@ -911,7 +937,9 @@ class OpenAIServingChat(OpenAIServing): index=i, delta=delta_message, logprobs=logprobs, - finish_reason=None) + finish_reason=None, + token_ids=(as_list(output.token_ids) + if 
request.return_token_ids else None)) # if the model is finished generating else: @@ -972,7 +1000,9 @@ class OpenAIServingChat(OpenAIServing): logprobs=logprobs, finish_reason=output.finish_reason if not auto_tools_called else "tool_calls", - stop_reason=output.stop_reason) + stop_reason=output.stop_reason, + token_ids=(as_list(output.token_ids) + if request.return_token_ids else None)) finish_reason_sent[i] = True @@ -1079,6 +1109,10 @@ class OpenAIServingChat(OpenAIServing): assert final_res is not None choices: list[ChatCompletionResponseChoice] = [] + if self.tool_call_id_type == 'kimi_k2': + history_tool_call_cnt = get_history_tool_calls_cnt(conversation) + else: + history_tool_call_cnt = 0 role = self.get_chat_request_role(request) for output in final_res.outputs: @@ -1184,17 +1218,26 @@ class OpenAIServingChat(OpenAIServing): assert content is not None tool_calls = TypeAdapter( list[FunctionDefinition]).validate_json(content) + tool_call_ids = [] + for tool_call in tool_calls: + tool_call_ids.append( + make_tool_call_id(id_type=self.tool_call_id_type, + func_name=tool_call.name, + idx=history_tool_call_cnt)) + history_tool_call_cnt += 1 message = ChatMessage( role=role, content="", - reasoning_content=reasoning_content, tool_calls=[ - tool_call_class(function=FunctionCall( - name=tool_call.name, - arguments=json.dumps(tool_call.parameters, - ensure_ascii=False))) - for tool_call in tool_calls - ]) + tool_call_class(id=tool_call_ids[i], + function=FunctionCall( + name=tool_call.name, + arguments=json.dumps( + tool_call.parameters, + ensure_ascii=False))) + for i, tool_call in enumerate(tool_calls) + ], + reasoning_content=reasoning_content) # if the request doesn't use tool choice # OR specifies to not use a tool @@ -1238,7 +1281,6 @@ class OpenAIServingChat(OpenAIServing): if (tool_call_info.content and len(tool_call_info.content) > 0): ret_content = tool_call_info.content - message = ChatMessage(role=role, reasoning_content=reasoning_content, 
content=ret_content) @@ -1259,7 +1301,10 @@ class OpenAIServingChat(OpenAIServing): logprobs=logprobs, finish_reason="tool_calls" if auto_tools_called else output.finish_reason if output.finish_reason else "stop", - stop_reason=output.stop_reason) + stop_reason=output.stop_reason, + token_ids=(as_list(output.token_ids) + if request.return_token_ids else None), + ) choices.append(choice_data) @@ -1300,6 +1345,8 @@ class OpenAIServingChat(OpenAIServing): choices=choices, usage=usage, prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs), + prompt_token_ids=(final_res.prompt_token_ids + if request.return_token_ids else None), kv_transfer_params=final_res.kv_transfer_params, ) @@ -1312,12 +1359,11 @@ class OpenAIServingChat(OpenAIServing): elif choice.message.tool_calls: # For tool calls, log the function name and arguments tool_call_descriptions = [] - for tool_call in choice.message.tool_calls: - if hasattr(tool_call.function, "name") and hasattr( - tool_call.function, "arguments"): + for tc in choice.message.tool_calls: + if hasattr(tc.function, "name") and hasattr( + tc.function, "arguments"): tool_call_descriptions.append( - f"{tool_call.function.name}({tool_call.function.arguments})" - ) + f"{tc.function.name}({tc.function.arguments})") tool_calls_str = ", ".join(tool_call_descriptions) output_text = f"[tool_calls: {tool_calls_str}]" @@ -1468,4 +1514,9 @@ class OpenAIServingChat(OpenAIServing): # Render prompt token ids. 
prompt_token_ids = render_for_completion(messages) engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) + + # Add cache_salt if provided in the request + if request.cache_salt is not None: + engine_prompt["cache_salt"] = request.cache_salt + return messages, [prompt_token_ids], [engine_prompt] diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index 377f7f6847179..1d510d0b60a2d 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -129,12 +129,14 @@ class ServingClassification(ClassificationMixin): models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], + log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, + log_error_stack=log_error_stack, ) async def create_classify( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 22c6b6250394c..b81fd63ece7a4 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -42,7 +42,7 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import merge_async_iterators +from vllm.utils import as_list, merge_async_iterators logger = init_logger(__name__) @@ -59,6 +59,7 @@ class OpenAIServingCompletion(OpenAIServing): return_tokens_as_token_ids: bool = False, enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, + log_error_stack: bool = False, ): super().__init__( engine_client=engine_client, @@ -67,6 +68,7 @@ class OpenAIServingCompletion(OpenAIServing): request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, 
enable_force_include_usage=enable_force_include_usage, + log_error_stack=log_error_stack, ) self.enable_prompt_tokens_details = enable_prompt_tokens_details self.default_sampling_params = ( @@ -365,6 +367,11 @@ class OpenAIServingCompletion(OpenAIServing): for output in res.outputs: i = output.index + prompt_idx * num_choices + # Useful when request.return_token_ids is True + # Returning prompt token IDs shares the same logic + # with the echo implementation. + prompt_token_ids_to_return: Optional[list[int]] = None + assert request.max_tokens is not None if request.echo and not has_echoed[i]: assert prompt_token_ids is not None @@ -385,6 +392,7 @@ class OpenAIServingCompletion(OpenAIServing): *(prompt_logprobs or []), *(output.logprobs or []), ] + prompt_token_ids_to_return = prompt_token_ids has_echoed[i] = True else: # return just the delta @@ -392,6 +400,12 @@ class OpenAIServingCompletion(OpenAIServing): delta_token_ids = output.token_ids out_logprobs = output.logprobs + # has_echoed[i] is reused here to indicate whether + # we have already returned the prompt token IDs. 
+ if not has_echoed[i]: + prompt_token_ids_to_return = prompt_token_ids + has_echoed[i] = True + if (not delta_text and not delta_token_ids and not previous_num_tokens[i]): # Chunked prefill case, don't return empty chunks @@ -428,6 +442,9 @@ class OpenAIServingCompletion(OpenAIServing): logprobs=logprobs, finish_reason=finish_reason, stop_reason=stop_reason, + prompt_token_ids=prompt_token_ids_to_return, + token_ids=(as_list(output.token_ids) if + request.return_token_ids else None), ) ], ) @@ -548,6 +565,10 @@ class OpenAIServingCompletion(OpenAIServing): finish_reason=output.finish_reason, stop_reason=output.stop_reason, prompt_logprobs=final_res.prompt_logprobs, + prompt_token_ids=(prompt_token_ids + if request.return_token_ids else None), + token_ids=(as_list(output.token_ids) + if request.return_token_ids else None), ) choices.append(choice_data) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 9dcad8e391c68..45c1932f1873c 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -593,11 +593,13 @@ class OpenAIServingEmbedding(EmbeddingMixin): request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, - request_logger=request_logger) + request_logger=request_logger, + log_error_stack=log_error_stack) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index d6f92a63301e8..a97935e109ef2 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -5,6 +5,7 @@ import io import json import sys import time +import traceback from collections.abc import 
AsyncGenerator, Iterable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor from http import HTTPStatus @@ -205,6 +206,7 @@ class OpenAIServing: request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, enable_force_include_usage: bool = False, + log_error_stack: bool = False, ): super().__init__() @@ -222,6 +224,7 @@ class OpenAIServing: self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {} + self.log_error_stack = log_error_stack def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer: """ @@ -412,6 +415,12 @@ class OpenAIServing: message: str, err_type: str = "BadRequestError", status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: + if self.log_error_stack: + exc_type, _, _ = sys.exc_info() + if exc_type is not None: + traceback.print_exc() + else: + traceback.print_stack() return ErrorResponse(error=ErrorInfo( message=message, type=err_type, code=status_code.value)) @@ -1006,8 +1015,8 @@ class OpenAIServing: # OPTIMIZATION priority = orig_priority - 1 + @staticmethod def _load_prompt_embeds( - self, prompt_embeds: Optional[Union[bytes, list[bytes]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None ) -> list[EmbedsPrompt]: @@ -1015,12 +1024,14 @@ class OpenAIServing: def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt: tensor = torch.load(io.BytesIO( pybase64.b64decode(embed, validate=True)), - weights_only=True) + weights_only=True, + map_location=torch.device("cpu")) assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( torch.float32, torch.bfloat16, torch.float16, ) + tensor = tensor.to_dense() if tensor.dim() > 2: tensor = tensor.squeeze(0) assert tensor.dim() == 2 diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 38745d001ade6..e8cb1aed84596 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -58,11 +58,13 @@ class 
OpenAIServingPooling(OpenAIServing): request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, - request_logger=request_logger) + request_logger=request_logger, + log_error_stack=log_error_stack) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 86c16df40e693..899cb07b2b37d 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -4,11 +4,11 @@ import asyncio import json import time -from collections.abc import AsyncGenerator, AsyncIterator +from collections.abc import AsyncGenerator, AsyncIterator, Sequence from contextlib import AsyncExitStack from copy import copy from http import HTTPStatus -from typing import Any, Callable, Final, Optional, Union +from typing import Callable, Final, Optional, Union import jinja2 import openai.types.responses as openai_responses_types @@ -25,6 +25,8 @@ from openai.types.responses import (ResponseCreatedEvent, ResponseReasoningItem, ResponseReasoningTextDeltaEvent, ResponseReasoningTextDoneEvent) +from openai.types.responses.response_output_text import (Logprob, + LogprobTopLogprob) # yapf: enable from openai.types.responses.response_reasoning_item import ( Content as ResponseReasoningTextContent) @@ -59,6 +61,8 @@ from vllm.logger import init_logger from vllm.outputs import CompletionOutput from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import SamplingParams +from vllm.sequence import Logprob as SampleLogprob +from vllm.sequence import SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid @@ -84,6 +88,7 @@ 
class OpenAIServingResponses(OpenAIServing): enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, enable_log_outputs: bool = False, + log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, @@ -92,6 +97,7 @@ class OpenAIServingResponses(OpenAIServing): request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, enable_force_include_usage=enable_force_include_usage, + log_error_stack=log_error_stack, ) self.chat_template = chat_template @@ -201,6 +207,12 @@ class OpenAIServingResponses(OpenAIServing): # (i.e., their request's `store=True` just because it's the default # value). request.store = False + if self.use_harmony and request.is_include_output_logprobs(): + return self.create_error_response( + err_type="invalid_request_error", + message="logprobs are not supported with gpt-oss models", + status_code=HTTPStatus.BAD_REQUEST, + ) # Handle the previous response ID. prev_response_id = request.previous_response_id @@ -238,10 +250,10 @@ class OpenAIServingResponses(OpenAIServing): raw_request.state.request_metadata = request_metadata if self.tool_server is not None and isinstance( - self.tool_server, MCPToolServer - ) and (request.background or request.stream) and request.tools and any( - tool.type in ["web_search_preview", "code_interpreter"] - for tool in request.tools): + self.tool_server, + MCPToolServer) and request.stream and request.tools and any( + tool.type in ["web_search_preview", "code_interpreter"] + for tool in request.tools): return self.create_error_response( "MCP tool server is not supported in background mode and " "streaming mode") @@ -255,103 +267,70 @@ class OpenAIServingResponses(OpenAIServing): builtin_tool_list.append("browser") if self.tool_server.has_tool("python"): builtin_tool_list.append("python") - async with AsyncExitStack() as exit_stack: - try: - if self.tool_server is not None: - # TODO: initialize tool sessions lazily when the session - # is 
actually used. - tool_session_ctxs: dict[str, Any] = { - tool_name: - exit_stack.enter_async_context( - self.tool_server.new_session(tool_name)) - for tool_name in builtin_tool_list - } - tool_sessions = {} - for tool_name in builtin_tool_list: - tool_sessions[tool_name] = ( - await tool_session_ctxs[tool_name]) - else: - assert len(builtin_tool_list) == 0 - tool_sessions = {} - for i, engine_prompt in enumerate(engine_prompts): - default_max_tokens = self.max_model_len - len( - engine_prompt["prompt_token_ids"]) - sampling_params = request.to_sampling_params( - default_max_tokens, self.default_sampling_params) - trace_headers = (None if raw_request is None else await - self._get_trace_headers( - raw_request.headers)) + if self.tool_server is not None: + available_tools = builtin_tool_list + else: + assert len(builtin_tool_list) == 0 + available_tools = [] + try: + for i, engine_prompt in enumerate(engine_prompts): + default_max_tokens = self.max_model_len - len( + engine_prompt["prompt_token_ids"]) + sampling_params = request.to_sampling_params( + default_max_tokens, self.default_sampling_params) - context: ConversationContext - if self.use_harmony: - if request.stream: - context = StreamingHarmonyContext( - messages, tool_sessions) - else: - context = HarmonyContext(messages, tool_sessions) + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + context: ConversationContext + if self.use_harmony: + if request.stream: + context = StreamingHarmonyContext( + messages, available_tools) else: - context = SimpleContext() - generator = self._generate_with_builtin_tools( - request_id=request.request_id, - request_prompt=request_prompts[i], - engine_prompt=engine_prompt, - sampling_params=sampling_params, - context=context, - lora_request=lora_request, - priority=request.priority, - trace_headers=trace_headers, - ) - generators.append(generator) - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - 
return self.create_error_response(str(e)) - - assert len(generators) == 1 - result_generator, = generators - - # Store the input messages. - if request.store: - self.msg_store[request.request_id] = messages - - if request.background: - created_time = int(time.time()) - response = ResponsesResponse.from_request( - request, - sampling_params, - model_name=model_name, - created_time=created_time, - output=[], - status="queued", - usage=None, + context = HarmonyContext(messages, available_tools) + else: + context = SimpleContext() + generator = self._generate_with_builtin_tools( + request_id=request.request_id, + request_prompt=request_prompts[i], + engine_prompt=engine_prompt, + sampling_params=sampling_params, + context=context, + lora_request=lora_request, + priority=request.priority, + trace_headers=trace_headers, ) - async with self.response_store_lock: - self.response_store[response.id] = response + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) - # Run the request in the background. - task = asyncio.create_task( - self._run_background_request( - request, - sampling_params, - result_generator, - context, - model_name, - tokenizer, - request_metadata, - created_time, - ), - name=f"create_{response.id}", - ) + assert len(generators) == 1 + result_generator, = generators - # For cleanup. - response_id = response.id - self.background_tasks[response_id] = task - task.add_done_callback( - lambda _: self.background_tasks.pop(response_id, None)) - return response + # Store the input messages. 
+ if request.store: + self.msg_store[request.request_id] = messages - if request.stream: - return self.responses_stream_generator( + if request.background: + created_time = int(time.time()) + response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=created_time, + output=[], + status="queued", + usage=None, + ) + async with self.response_store_lock: + self.response_store[response.id] = response + + # Run the request in the background. + task = asyncio.create_task( + self._run_background_request( request, sampling_params, result_generator, @@ -359,21 +338,41 @@ class OpenAIServingResponses(OpenAIServing): model_name, tokenizer, request_metadata, - ) + created_time, + ), + name=f"create_{response.id}", + ) - try: - return await self.responses_full_generator( - request, - sampling_params, - result_generator, - context, - model_name, - tokenizer, - request_metadata, - ) - except Exception as e: - return self.create_error_response(str(e)) - return self.create_error_response("Should not reach here") + # For cleanup. 
+ response_id = response.id + self.background_tasks[response_id] = task + task.add_done_callback( + lambda _: self.background_tasks.pop(response_id, None)) + return response + + if request.stream: + return self.responses_stream_generator( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + ) + + try: + return await self.responses_full_generator( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + ) + except Exception as e: + return self.create_error_response(str(e)) async def _make_request( self, @@ -408,6 +407,11 @@ class OpenAIServingResponses(OpenAIServing): request, prev_response) prompt_token_ids = render_for_completion(messages) engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) + + # Add cache_salt if provided in the request + if request.cache_salt is not None: + engine_prompt["cache_salt"] = request.cache_salt + return messages, [prompt_token_ids], [engine_prompt] async def responses_full_generator( @@ -424,14 +428,16 @@ class OpenAIServingResponses(OpenAIServing): if created_time is None: created_time = int(time.time()) - try: - async for _ in result_generator: - pass - except asyncio.CancelledError: - return self.create_error_response("Client disconnected") - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + async with AsyncExitStack() as exit_stack: + try: + await context.init_tool_sessions(self.tool_server, exit_stack) + async for _ in result_generator: + pass + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) if self.use_harmony: assert isinstance(context, HarmonyContext) @@ -486,6 +492,51 @@ class OpenAIServingResponses(OpenAIServing): self.response_store[response.id] = response return response + 
def _topk_logprobs(self, logprobs: dict[int, + SampleLogprob], top_logprobs: int, + tokenizer: AnyTokenizer) -> list[LogprobTopLogprob]: + """Returns the top-k logprobs from the logprobs dictionary.""" + out = [] + for i, (token_id, _logprob) in enumerate(logprobs.items()): + if i >= top_logprobs: + break + text = _logprob.decoded_token if _logprob.decoded_token \ + is not None else tokenizer.decode([token_id]) + out.append( + LogprobTopLogprob( + token=text, + logprob=max(_logprob.logprob, -9999.0), + bytes=list(text.encode("utf-8", errors="replace")), + )) + return out + + def _create_response_logprobs( + self, + token_ids: Sequence[int], + logprobs: Optional[SampleLogprobs], + tokenizer: AnyTokenizer, + top_logprobs: Optional[int] = None) -> list[Logprob]: + assert logprobs is not None, "logprobs must be provided" + assert len(token_ids) == len(logprobs), ( + "token_ids and logprobs.token_ids must have the same length") + out = [] + for i, token_id in enumerate(token_ids): + logprob = logprobs[i] + token_logprob = logprob[token_id] + text = token_logprob.decoded_token if token_logprob.decoded_token \ + is not None else tokenizer.decode([token_id]) + out.append( + Logprob( + token=text, + logprob=max(token_logprob.logprob, -9999.0), + bytes=list(text.encode("utf-8", errors="replace")), + top_logprobs=self._topk_logprobs(logprob, + top_logprobs=top_logprobs, + tokenizer=tokenizer) + if top_logprobs else [], + )) + return out + def _make_response_output_items( self, request: ResponsesRequest, @@ -542,7 +593,12 @@ class OpenAIServingResponses(OpenAIServing): text=content, annotations=[], # TODO type="output_text", - logprobs=None, # TODO + logprobs=self._create_response_logprobs( + token_ids=final_output.token_ids, + logprobs=final_output.logprobs, + tokenizer=tokenizer, + top_logprobs=request.top_logprobs, + ) if request.is_include_output_logprobs() else None, ) message = ResponseOutputMessage( id=f"msg_{random_uuid()}", @@ -773,7 +829,7 @@ class 
OpenAIServingResponses(OpenAIServing): status_code=HTTPStatus.BAD_REQUEST, ) - async def responses_stream_generator( + async def _process_streaming_events( self, request: ResponsesRequest, sampling_params: SamplingParams, @@ -782,18 +838,8 @@ class OpenAIServingResponses(OpenAIServing): model_name: str, tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, - created_time: Optional[int] = None, + created_time: int, ) -> AsyncGenerator[str, None]: - # TODO: - # 1. Handle disconnect - - if not isinstance(context, StreamingHarmonyContext): - raise NotImplementedError( - "Streaming is not supported for responses API without Harmony." - ) - - created_time = created_time or int(time.time()) - sequence_number = 0 def _send_event(event: BaseModel): @@ -1004,7 +1050,48 @@ class OpenAIServingResponses(OpenAIServing): delta=ctx.parser.last_content_delta, sequence_number=-1, )) - + # built-in tools will be triggered on the analysis channel + # However, occasionally built-in tools will + # still be output to commentary. + elif (ctx.parser.current_channel == "commentary" + or ctx.parser.current_channel == "analysis" + ) and ctx.parser.current_recipient == "python": + if not sent_output_item_added: + sent_output_item_added = True + yield _send_event( + openai_responses_types. + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. + ResponseCodeInterpreterToolCallParam( + type="code_interpreter_call", + id=current_item_id, + code=None, + container_id="auto", + outputs=None, + status="in_progress", + ), + )) + yield _send_event( + openai_responses_types. + ResponseCodeInterpreterCallInProgressEvent( + type= + "response.code_interpreter_call.in_progress", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + )) + yield _send_event( + openai_responses_types. 
+ ResponseCodeInterpreterCallCodeDeltaEvent( + type="response.code_interpreter_call_code.delta", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + delta=ctx.parser.last_content_delta, + )) if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0: previous_item = ctx.parser.messages[-1] if (self.tool_server is not None @@ -1100,30 +1187,6 @@ class OpenAIServingResponses(OpenAIServing): and self.tool_server.has_tool("python") and previous_item.recipient is not None and previous_item.recipient.startswith("python")): - yield _send_event( - openai_responses_types.ResponseOutputItemAddedEvent( - type="response.output_item.added", - sequence_number=-1, - output_index=current_output_index, - item=openai_responses_types. - ResponseCodeInterpreterToolCallParam( - type="code_interpreter_call", - id=current_item_id, - code="", - container_id="auto", - outputs=[], - status="in_progress", - ), - )) - yield _send_event( - openai_responses_types. - ResponseCodeInterpreterCallInProgressEvent( - type="response.code_interpreter_call.in_progress", - sequence_number=-1, - output_index=current_output_index, - item_id=current_item_id, - )) - # TODO: do we need to add delta event here? yield _send_event( openai_responses_types. ResponseCodeInterpreterCallCodeDoneEvent( @@ -1131,7 +1194,8 @@ class OpenAIServingResponses(OpenAIServing): sequence_number=-1, output_index=current_output_index, item_id=current_item_id, - code=previous_item.content[0].text)) + code=previous_item.content[0].text, + )) yield _send_event( openai_responses_types. 
ResponseCodeInterpreterCallInterpretingEvent( @@ -1187,3 +1251,31 @@ class OpenAIServingResponses(OpenAIServing): sequence_number=-1, response=final_response.model_dump(), )) + + async def responses_stream_generator( + self, + request: ResponsesRequest, + sampling_params: SamplingParams, + result_generator: AsyncIterator[Optional[ConversationContext]], + context: ConversationContext, + model_name: str, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + created_time: Optional[int] = None, + ) -> AsyncGenerator[str, None]: + # TODO: + # 1. Handle disconnect + + if not isinstance(context, StreamingHarmonyContext): + raise NotImplementedError( + "Streaming is not supported for responses API without Harmony." + ) + + created_time = created_time or int(time.time()) + + async with AsyncExitStack() as exit_stack: + await context.init_tool_sessions(self.tool_server, exit_stack) + async for event_data in self._process_streaming_events( + request, sampling_params, result_generator, context, + model_name, tokenizer, request_metadata, created_time): + yield event_data diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index c246274514dbf..37838e22a4002 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -47,11 +47,13 @@ class ServingScores(OpenAIServing): models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, - request_logger=request_logger) + request_logger=request_logger, + log_error_stack=log_error_stack) async def _embedding_score( self, diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 58d720474768b..2f258255d5f16 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -39,11 
+39,13 @@ class OpenAIServingTokenization(OpenAIServing): request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, - request_logger=request_logger) + request_logger=request_logger, + log_error_stack=log_error_stack) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 0d6989fe91bfa..9ba58d4425221 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -32,13 +32,15 @@ class OpenAIServingTranscription(OpenAISpeechToText): *, request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, + log_error_stack: bool = False, ): super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - task_type="transcribe") + task_type="transcribe", + log_error_stack=log_error_stack) async def create_transcription( self, audio_data: bytes, request: TranscriptionRequest, @@ -88,13 +90,15 @@ class OpenAIServingTranslation(OpenAISpeechToText): *, request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, + log_error_stack: bool = False, ): super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - task_type="translate") + task_type="translate", + log_error_stack=log_error_stack) async def create_translation( self, audio_data: bytes, request: TranslationRequest, diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index 
01140a4bfea7e..1cbd7dba393f6 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -53,12 +53,14 @@ class OpenAISpeechToText(OpenAIServing): request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, task_type: Literal["transcribe", "translate"] = "transcribe", + log_error_stack: bool = False, ): super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids) + return_tokens_as_token_ids=return_tokens_as_token_ids, + log_error_stack=log_error_stack) self.default_sampling_params = ( self.model_config.get_diff_sampling_param()) @@ -200,7 +202,22 @@ class OpenAISpeechToText(OpenAIServing): for result_generator in list_result_generator: async for op in result_generator: text += op.outputs[0].text - return cast(T, response_class(text=text)) + + if self.task_type == "transcribe": + # add usage in TranscriptionResponse. 
+ usage = { + "type": "duration", + # rounded up as per openAI specs + "seconds": int(math.ceil(duration_s)), + } + final_response = cast(T, response_class(text=text, + usage=usage)) + else: + # no usage in response for translation task + final_response = cast( + T, response_class(text=text)) # type: ignore[call-arg] + + return final_response except asyncio.CancelledError: return self.create_error_response("Client disconnected") except ValueError as e: diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 099e456aa486f..44aa1208a54c7 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -3,6 +3,7 @@ from .abstract_tool_parser import ToolParser, ToolParserManager from .deepseekv3_tool_parser import DeepSeekV3ToolParser +from .deepseekv31_tool_parser import DeepSeekV31ToolParser from .glm4_moe_tool_parser import Glm4MoeModelToolParser from .granite_20b_fc_tool_parser import Granite20bFCToolParser from .granite_tool_parser import GraniteToolParser @@ -18,6 +19,7 @@ from .mistral_tool_parser import MistralToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser from .pythonic_tool_parser import PythonicToolParser from .qwen3coder_tool_parser import Qwen3CoderToolParser +from .seed_oss_tool_parser import SeedOssToolParser from .step3_tool_parser import Step3ToolParser from .xlam_tool_parser import xLAMToolParser @@ -35,11 +37,13 @@ __all__ = [ "PythonicToolParser", "Phi4MiniJsonToolParser", "DeepSeekV3ToolParser", + "DeepSeekV31ToolParser", "xLAMToolParser", "MinimaxToolParser", "KimiK2ToolParser", "HunyuanA13BToolParser", "Glm4MoeModelToolParser", "Qwen3CoderToolParser", + "SeedOssToolParser", "Step3ToolParser", ] diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py new file mode 100644 index 0000000000000..ff9188190f3f0 --- /dev/null +++ 
# b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py @@ -0,0 +1,367 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence
from typing import Union

import regex as re

from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
                                              ExtractedToolCallInformation,
                                              FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer

logger = init_logger(__name__)


@ToolParserManager.register_module("deepseek_v31")
class DeepSeekV31ToolParser(ToolParser):
    """Tool-call parser registered under the name ``deepseek_v31``.

    Tool calls are recognised by sentinel tokens: a "tool calls" begin/end
    pair wraps the whole tool section, and inside it every call is written
    as ``<call begin>NAME<sep>ARGUMENTS<call end>``.  The class offers a
    one-shot parser (``extract_tool_calls``) and an incremental one
    (``extract_tool_calls_streaming``) that keeps per-request state on the
    instance.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        """Set up streaming state, sentinel tokens, regexes and token ids.

        Raises:
            ValueError: if no model tokenizer is available.
            RuntimeError: if the tool-calls begin/end tokens are missing
                from the tokenizer vocabulary.
        """
        super().__init__(tokenizer)

        # Streaming state.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = (
            [])  # map what has been streamed for each tool so far to a list

        # NOTE(review): the sentinel tokens are spelled with ASCII "|" here;
        # confirm they match the tokenizer vocabulary character for character.
        self.tool_calls_start_token: str = "<|tool▁calls▁begin|>"
        self.tool_calls_end_token: str = "<|tool▁calls▁end|>"

        self.tool_call_start_token: str = "<|tool▁call▁begin|>"
        self.tool_call_end_token: str = "<|tool▁call▁end|>"

        # "|" is the alternation operator in a regex, so the tokens must be
        # escaped before they are embedded in a pattern; unescaped, the
        # pattern would split into unrelated alternatives.
        call_begin = re.escape(self.tool_call_start_token)
        call_sep = re.escape("<|tool▁sep|>")
        call_end = re.escape(self.tool_call_end_token)

        # Non-greedy groups keep consecutive calls separate; DOTALL lets
        # the arguments span several lines.
        self.tool_call_regex = re.compile(
            call_begin + r"(?P<function_name>.*?)" + call_sep +
            r"(?P<function_arguments>.*?)" + call_end, re.DOTALL)

        # Applied to the text that follows the last call-begin token while
        # the call is still being generated.
        self.stream_tool_call_portion_regex = re.compile(
            r"(?P<function_name>.*?)" + call_sep +
            r"(?P<function_arguments>.*)", re.DOTALL)

        self.stream_tool_call_name_regex = re.compile(
            r"(?P<function_name>.*?)" + call_sep, re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction.")
        self.tool_calls_start_token_id = self.vocab.get(
            self.tool_calls_start_token)
        self.tool_calls_end_token_id = self.vocab.get(
            self.tool_calls_end_token)

        self.tool_call_start_token_id = self.vocab.get(
            self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        if (self.tool_calls_start_token_id is None
                or self.tool_calls_end_token_id is None):
            raise RuntimeError(
                "DeepSeek-V3.1 Tool parser could not locate tool call "
                "start/end tokens in the tokenizer!")

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse every complete tool call out of a finished model output.

        Args:
            model_output: the full generated text.
            request: the originating request (not inspected here).

        Returns:
            ``tools_called=False`` with the untouched text as content when
            the tool-calls begin token is absent or parsing raises;
            otherwise the parsed calls, with the text that precedes the
            tool-calls begin token as content (``None`` when it is empty).
        """
        # sanity check; avoid unnecessary processing
        if self.tool_calls_start_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        try:
            # the pattern has two groups, so findall yields one
            # (function_name, function_arguments) tuple per complete call
            function_call_tuples = self.tool_call_regex.findall(model_output)

            tool_calls = [
                ToolCall(
                    type="function",
                    function=FunctionCall(name=function_name,
                                          arguments=function_args),
                ) for function_name, function_args in function_call_tuples
            ]

            content = model_output[:model_output.
                                   find(self.tool_calls_start_token)]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Turn one streamed chunk into a content or tool-call delta.

        The position in the stream is derived by counting call begin/end
        token ids in the previous and current token sequences; the state
        kept on the instance records which tool is current, whether its
        name was already sent and which arguments were already streamed.

        Returns:
            A ``DeltaMessage`` carrying plain content, a tool-call header
            (index, id, name) or an arguments fragment; ``None`` when there
            is nothing to emit for this chunk or an error was logged.
        """
        logger.debug("delta_text: %s", delta_text)
        logger.debug("delta_token_ids: %s", delta_token_ids)
        # check to see if we should be streaming a tool call - is there a
        # tool-calls begin token anywhere in the output so far?
        if self.tool_calls_start_token_id not in current_token_ids:
            logger.debug("No tool call tokens found!")
            return DeltaMessage(content=delta_text)
        # the section wrappers are never forwarded to the client
        delta_text = delta_text.replace(self.tool_calls_start_token,
                                        "").replace(self.tool_calls_end_token,
                                                    "")
        try:

            # figure out where we are in the parsing by counting tool call
            # start & end tags
            prev_tool_start_count = previous_token_ids.count(
                self.tool_call_start_token_id)
            prev_tool_end_count = previous_token_ids.count(
                self.tool_call_end_token_id)
            cur_tool_start_count = current_token_ids.count(
                self.tool_call_start_token_id)
            cur_tool_end_count = current_token_ids.count(
                self.tool_call_end_token_id)
            tool_call_portion = None
            text_portion = None

            # case: if we're generating text, OR rounding out a tool call
            if (cur_tool_start_count == cur_tool_end_count
                    and prev_tool_end_count == cur_tool_end_count
                    and self.tool_call_end_token not in delta_text):
                logger.debug("Generating text content! skipping tool parsing.")
                return DeltaMessage(content=delta_text)

            if self.tool_call_end_token in delta_text:
                logger.debug("tool_call_end_token in delta_text")
                full_text = current_text + delta_text
                tool_call_portion = full_text.split(
                    self.tool_call_start_token)[-1].split(
                        self.tool_call_end_token)[0].rstrip()
                # split once so the text after the end token is taken from
                # the untruncated delta
                delta_parts = delta_text.split(self.tool_call_end_token)
                delta_text = delta_parts[0].rstrip()
                text_portion = delta_parts[-1].lstrip()

            # case -- we're starting a new tool call
            if (cur_tool_start_count > cur_tool_end_count
                    and cur_tool_start_count > prev_tool_start_count):
                if len(delta_token_ids) > 1:
                    tool_call_portion = current_text.split(
                        self.tool_call_start_token)[-1]
                else:
                    tool_call_portion = None
                    delta = None

                text_portion = None

                # set cursors and state appropriately
                self.current_tool_id += 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("Starting on a new tool %s", self.current_tool_id)

            # case -- we're updating an existing tool call
            elif (cur_tool_start_count > cur_tool_end_count
                  and cur_tool_start_count == prev_tool_start_count):

                # get the portion of the text that's the tool call
                tool_call_portion = current_text.split(
                    self.tool_call_start_token)[-1]
                text_portion = None

            # case -- the current tool call is being closed.
            elif (cur_tool_start_count == cur_tool_end_count
                  and cur_tool_end_count >= prev_tool_end_count):
                if self.prev_tool_call_arr is None or len(
                        self.prev_tool_call_arr) == 0:
                    logger.debug(
                        "attempting to close tool call, but no tool call")
                    return None
                diff = self.prev_tool_call_arr[self.current_tool_id].get(
                    "arguments")
                if diff:
                    # the previously stored arguments only gate this branch;
                    # the fragment that is sent comes from the current delta
                    if '"}' not in delta_text:
                        return None
                    end_loc = delta_text.rindex('"}')
                    diff = delta_text[:end_loc] + '"}'
                    logger.debug(
                        "Finishing tool and found diff that had not "
                        "been streamed yet: %s",
                        diff,
                    )
                    self.streamed_args_for_tool[self.current_tool_id] += diff
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            function=DeltaFunctionCall(
                                arguments=diff).model_dump(exclude_none=True),
                        )
                    ])

            # case -- otherwise we're just generating text
            else:
                text = delta_text.replace(self.tool_call_start_token, "")
                text = text.replace(self.tool_call_end_token, "")
                delta = DeltaMessage(tool_calls=[], content=text)
                return delta

            current_tool_call = dict()
            if tool_call_portion:
                current_tool_call_matches = (
                    self.stream_tool_call_portion_regex.match(
                        tool_call_portion))
                if current_tool_call_matches:
                    tool_name, tool_args = current_tool_call_matches.groups()
                    current_tool_call["name"] = tool_name
                    current_tool_call["arguments"] = tool_args
                else:
                    current_tool_call_name_matches = (
                        self.stream_tool_call_name_regex.match(
                            tool_call_portion))
                    if current_tool_call_name_matches:
                        # .group() gives the name string; .groups() would
                        # give a one-element tuple
                        tool_name = current_tool_call_name_matches.group(
                            "function_name")
                        current_tool_call["name"] = tool_name
                        current_tool_call["arguments"] = ""
                    else:
                        logger.debug("Not enough token")
                        return None

            # case - we haven't sent the tool name yet. If it's available, send
            # it. otherwise, wait until it's available.
            if not self.current_tool_name_sent:
                function_name: Union[str, None] = current_tool_call.get("name")
                if function_name:
                    self.current_tool_name_sent = True
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            type="function",
                            id=make_tool_call_id(),
                            function=DeltaFunctionCall(
                                name=function_name).model_dump(
                                    exclude_none=True),
                        )
                    ])
                else:
                    return None

            # case -- otherwise, send the tool call delta

            # if the tool call portion is None, send the delta as text
            if tool_call_portion is None:
                # if there's text but not tool calls, send that -
                # otherwise None to skip chunk
                delta = (DeltaMessage(
                    content=delta_text) if text_portion is not None else None)
                return delta

            # now, the nitty-gritty of tool calls
            # now we have the portion to parse as tool call.

            logger.debug("Trying to parse current tool call with ID %s",
                         self.current_tool_id)

            # if we're starting a new tool call, push an empty object in as
            # a placeholder for the arguments
            if len(self.prev_tool_call_arr) <= self.current_tool_id:
                self.prev_tool_call_arr.append({})

            # main logic for tool parsing here - compare the arguments text
            # seen on the previous chunk to the arguments text seen now
            prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                "arguments")
            cur_arguments = current_tool_call.get("arguments")

            logger.debug("diffing old arguments: %s", prev_arguments)
            logger.debug("against new ones: %s", cur_arguments)

            # case -- no arguments have been created yet. skip sending a delta.
            if not cur_arguments and not prev_arguments:
                logger.debug("Skipping text %s - no arguments", delta_text)
                delta = None

            # case -- prev arguments are defined, but none are now.
            # probably impossible, but not a fatal error - just keep going
            elif not cur_arguments and prev_arguments:
                logger.error("should be impossible to have arguments reset "
                             "mid-call. skipping streaming anything.")
                delta = None

            # case -- we now have the first info about arguments available
            elif cur_arguments and not prev_arguments:

                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        function=DeltaFunctionCall(
                            arguments=cur_arguments).model_dump(
                                exclude_none=True),
                    )
                ])
                self.streamed_args_for_tool[
                    self.current_tool_id] = cur_arguments

            # last case -- we have an update to existing arguments.
            elif cur_arguments and prev_arguments:
                if (isinstance(delta_text, str)
                        and cur_arguments != prev_arguments
                        and len(cur_arguments) > len(prev_arguments)
                        and cur_arguments.startswith(prev_arguments)):
                    # only the newly appended suffix is streamed
                    delta_arguments = cur_arguments[len(prev_arguments):]
                    logger.debug("got diff %s", delta_text)

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            function=DeltaFunctionCall(
                                arguments=delta_arguments).model_dump(
                                    exclude_none=True),
                        )
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] = cur_arguments
                else:
                    delta = None

            # handle saving the state for the current tool into
            # the "prev" list for use in diffing for the next iteration
            if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
                self.prev_tool_call_arr[
                    self.current_tool_id] = current_tool_call
            else:
                self.prev_tool_call_arr.append(current_tool_call)

            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            return None  # do not stream a delta. skip this token ID.
diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py index da4760ad1b642..ac272b0c3b205 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py @@ -6,7 +6,7 @@ from typing import Union import regex as re -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -267,7 +267,7 @@ class DeepSeekV3ToolParser(ToolParser): DeltaToolCall( index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True), diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py index 5508ba6a39408..824b100f357b5 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py @@ -10,7 +10,7 @@ import partial_json_parser import regex as re from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -203,7 +203,7 @@ class Granite20bFCToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py index 
fcc5b7edda83f..ac517616a95b4 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -8,7 +8,7 @@ from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -185,7 +185,7 @@ class GraniteToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index c7030d34d453e..a6ce33af6bd00 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -9,7 +9,7 @@ import partial_json_parser import regex as re from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -52,14 +52,51 @@ class Hermes2ProToolParser(ToolParser): raise ValueError( "The model tokenizer must be passed to the ToolParser " "constructor during construction.") - self.tool_call_start_token_id = self.vocab.get( - self.tool_call_start_token) - self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) - if (self.tool_call_start_token_id is None - or self.tool_call_end_token_id is None): - raise RuntimeError( - "Hermes 2 Pro Tool parser could not locate tool call start/end " - "tokens in the tokenizer!") + 
self.tool_call_start_token_ids = self.model_tokenizer.encode( + self.tool_call_start_token, add_special_tokens=False) + self.tool_call_end_token_ids = self.model_tokenizer.encode( + self.tool_call_end_token, add_special_tokens=False) + + self.tool_call_start_token_array = [ + self.model_tokenizer.decode([token_id]) + for token_id in self.tool_call_start_token_ids + ] + + self.tool_call_end_token_array = [ + self.model_tokenizer.decode([token_id]) + for token_id in self.tool_call_end_token_ids + ] + + self.buffered_delta_text = "" + + # Very simple idea: when encountering tokens like <, tool, _call, >, + # <, /, tool, _call, >, store them in a buffer. + # When the last token is encountered, empty the buffer and return it. + # If a token appears in an incorrect sequence while storing in the buffer, + # return the preceding buffer along with the token. + def tool_call_delta_buffer(self, delta_text: str): + # If the sequence of tool_call_start or tool_call_end tokens is not yet + # complete, fill the buffer with the token and return "". + if (delta_text in self.tool_call_start_token_array + or delta_text in self.tool_call_end_token_array): + # If delta_text is the last token of tool_call_start_token or + # tool_call_end_token, empty the buffer and return + # the buffered text + delta_text. + if (delta_text == self.tool_call_start_token_array[-1] + or delta_text == self.tool_call_end_token_array[-1]): + buffered_text = self.buffered_delta_text + self.buffered_delta_text = "" + return buffered_text + delta_text + else: + self.buffered_delta_text = self.buffered_delta_text + delta_text + return "" + else: + if self.buffered_delta_text: + buffered_text = self.buffered_delta_text + self.buffered_delta_text = "" + return buffered_text + delta_text + else: + return delta_text def extract_tool_calls( self, @@ -124,11 +161,23 @@ class Hermes2ProToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: + # 1. 
All tokens are parsed based on _text, not token_ids. + # 2. All incoming text data is processed by the tool_call_delta_buffer + # function for buffering before being used for parsing. + + delta_text = self.tool_call_delta_buffer(delta_text) + # If the last characters of previous_text + # match self.buffered_delta_text, remove only the matching part. + if (len(previous_text) >= len(self.buffered_delta_text) + and previous_text[-len(self.buffered_delta_text):] + == self.buffered_delta_text): + previous_text = previous_text[:-len(self.buffered_delta_text)] + current_text = previous_text + delta_text logger.debug("delta_text: %s", delta_text) logger.debug("delta_token_ids: %s", delta_token_ids) # check to see if we should be streaming a tool call - is there a - if self.tool_call_start_token_id not in current_token_ids: + if self.tool_call_start_token not in current_text: logger.debug("No tool call tokens found!") return DeltaMessage(content=delta_text) @@ -136,14 +185,12 @@ class Hermes2ProToolParser(ToolParser): # figure out where we are in the parsing by counting tool call # start & end tags - prev_tool_start_count = previous_token_ids.count( - self.tool_call_start_token_id) - prev_tool_end_count = previous_token_ids.count( - self.tool_call_end_token_id) - cur_tool_start_count = current_token_ids.count( - self.tool_call_start_token_id) - cur_tool_end_count = current_token_ids.count( - self.tool_call_end_token_id) + prev_tool_start_count = previous_text.count( + self.tool_call_start_token) + prev_tool_end_count = previous_text.count(self.tool_call_end_token) + cur_tool_start_count = current_text.count( + self.tool_call_start_token) + cur_tool_end_count = current_text.count(self.tool_call_end_token) tool_call_portion = None text_portion = None @@ -260,7 +307,7 @@ class Hermes2ProToolParser(ToolParser): return DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), 
function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 92004de030d14..6ef8fadf59ac5 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -8,7 +8,7 @@ from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -107,7 +107,7 @@ class Internlm2ToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py index 66b483d8b0f66..3b41f6034704c 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -9,7 +9,7 @@ import partial_json_parser import regex as re from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -222,7 +222,7 @@ class JambaToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git 
a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index 194a144ad576e..31b19c8db4163 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -10,7 +10,7 @@ import regex as re from partial_json_parser.core.options import Allow from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -213,7 +213,7 @@ class Llama3JsonToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py index 226309ef293a9..0fd62f0b6a7f1 100644 --- a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py @@ -7,7 +7,7 @@ from typing import Any, Optional, Union import regex as re -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -394,7 +394,7 @@ class MinimaxToolParser(ToolParser): sent_tools.append({ "sent_name": False, "sent_arguments": "", - "id": random_tool_call_id(), + "id": make_tool_call_id(), }) while len(tool_ids) < tool_count: @@ -461,7 +461,8 @@ class MinimaxToolParser(ToolParser): i += 1 return boundaries - def _extract_tool_args(self, tool_content: str, args_match) -> str: + def _extract_tool_args(self, tool_content: str, + 
args_match: re.Match[str]) -> str: """ Extract tool arguments from tool content. diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py index 5501028cf36b8..85dd56213c6ac 100644 --- a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py @@ -8,7 +8,7 @@ from typing import Any, Optional import regex as re from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage, ExtractedToolCallInformation, @@ -74,7 +74,7 @@ class Phi4MiniJsonToolParser(ToolParser): tool_calls: list[ToolCall] = [ ToolCall( - id=random_tool_call_id(), + id=make_tool_call_id(), type="function", function=FunctionCall( name=raw_function_call["name"], diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py index cf4d0b231aee1..955813ddd3408 100644 --- a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import ast import json import uuid from collections.abc import Sequence @@ -22,7 +22,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module(["qwen3_coder"]) +@ToolParserManager.register_module("qwen3_coder") class Qwen3CoderToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): @@ -30,6 +30,8 @@ class Qwen3CoderToolParser(ToolParser): self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] + # Override base class type - we use string IDs for tool calls + 
self.current_tool_id: Optional[str] = None # type: ignore self.streamed_args_for_tool: list[str] = [] # Sentinel tokens for streaming mode @@ -42,20 +44,6 @@ class Qwen3CoderToolParser(ToolParser): self.is_tool_call_started: bool = False self.failed_count: int = 0 - # Streaming state variables - self.current_tool_index: int = 0 - self.header_sent: bool = False - self.current_tool_string_id: Optional[str] = None - self.current_function_name: Optional[str] = None - self.current_param_name: Optional[str] = None - self.current_param_value: str = "" - self.param_count: int = 0 - self.in_param: bool = False - self.in_function: bool = False - self.accumulated_text: str = "" - self.json_started: bool = False - self.json_closed: bool = False - # Enhanced streaming state - reset for each new message self._reset_streaming_state() @@ -67,7 +55,8 @@ class Qwen3CoderToolParser(ToolParser): self.tool_call_function_regex = re.compile( r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL) self.tool_call_parameter_regex = re.compile( - r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL) + r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)", + re.DOTALL) if not self.model_tokenizer: raise ValueError( @@ -84,8 +73,8 @@ class Qwen3CoderToolParser(ToolParser): "Qwen3 XML Tool parser could not locate tool call start/end " "tokens in the tokenizer!") - logger.debug("vLLM Successfully import tool parser %s !", - self.__class__.__name__) + logger.info("vLLM Successfully import tool parser %s !", + self.__class__.__name__) def _generate_tool_call_id(self) -> str: """Generate a unique tool call ID.""" @@ -96,7 +85,7 @@ class Qwen3CoderToolParser(ToolParser): self.current_tool_index = 0 self.is_tool_call_started = False self.header_sent = False - self.current_tool_string_id = None + self.current_tool_id = None self.current_function_name = None self.current_param_name = None self.current_param_value = "" @@ -106,127 +95,122 @@ class 
Qwen3CoderToolParser(ToolParser): self.accumulated_text = "" self.json_started = False self.json_closed = False + # Store accumulated parameters for type conversion + self.accumulated_params = {} + self.streaming_request = None + + def _get_arguments_config( + self, func_name: str, + tools: Optional[list[ChatCompletionToolsParam]]) -> dict: + """Extract argument configuration for a function.""" + if tools is None: + return {} + for config in tools: + if not hasattr(config, "type") or not (hasattr( + config, "function") and hasattr(config.function, "name")): + continue + if config.type == "function" and config.function.name == func_name: + if not hasattr(config.function, "parameters"): + return {} + params = config.function.parameters + if isinstance(params, dict) and "properties" in params: + return params["properties"] + elif isinstance(params, dict): + return params + else: + return {} + logger.warning("Tool '%s' is not defined in the tools list.", + func_name) + return {} + + def _convert_param_value(self, param_value: str, param_name: str, + param_config: dict, func_name: str) -> Any: + """Convert parameter value based on its type in the schema.""" + # Handle null value for any type + if param_value.lower() == "null": + return None + + if param_name not in param_config: + if param_config != {}: + logger.warning( + "Parsed parameter '%s' is not defined in the tool " + "parameters for tool '%s', directly returning the " + "string value.", param_name, func_name) + return param_value + + if isinstance(param_config[param_name], + dict) and "type" in param_config[param_name]: + param_type = str(param_config[param_name]["type"]).strip().lower() + else: + param_type = "string" + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: + return param_value + elif param_type.startswith("int") or param_type.startswith( + "uint") or param_type.startswith( + "long") or param_type.startswith( + "short") or param_type.startswith("unsigned"): + try: + return 
int(param_value) + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not an " + "integer in tool '%s', degenerating to string.", + param_value, param_name, func_name) + return param_value + elif param_type.startswith("num") or param_type.startswith("float"): + try: + float_param_value = float(param_value) + return float_param_value if float_param_value - int( + float_param_value) != 0 else int(float_param_value) + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not a float " + "in tool '%s', degenerating to string.", param_value, + param_name, func_name) + return param_value + elif param_type in ["boolean", "bool", "binary"]: + param_value = param_value.lower() + if param_value not in ["true", "false"]: + logger.warning( + "Parsed value '%s' of parameter '%s' is not a boolean " + "(`true` or `false`) in tool '%s', degenerating to " + "false.", param_value, param_name, func_name) + return param_value == "true" + else: + if param_type in ["object", "array", "arr" + ] or param_type.startswith( + "dict") or param_type.startswith("list"): + try: + param_value = json.loads(param_value) + return param_value + except (json.JSONDecodeError, TypeError, ValueError): + logger.warning( + "Parsed value '%s' of parameter '%s' cannot be " + "parsed with json.loads in tool '%s', will try " + "other methods to parse it.", param_value, param_name, + func_name) + try: + param_value = ast.literal_eval(param_value) # safer + except (ValueError, SyntaxError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' cannot be " + "converted via Python `ast.literal_eval()` in tool " + "'%s', degenerating to string.", param_value, param_name, + func_name) + return param_value def _parse_xml_function_call( self, function_call_str: str, tools: Optional[list[ChatCompletionToolsParam]] ) -> Optional[ToolCall]: - def get_arguments_config(func_name: str) -> dict: - if tools is None: - return {} - for 
config in tools: - if not hasattr(config, "type") or not ( - hasattr(config, "function") - and hasattr(config.function, "name")): - continue - if (config.type == "function" - and config.function.name == func_name): - if not hasattr(config.function, "parameters"): - return {} - params = config.function.parameters - if isinstance(params, dict) and "properties" in params: - return params["properties"] - elif isinstance(params, dict): - return params - else: - return {} - logger.warning("Tool '%s' is not defined in the tools list.", - func_name) - return {} - - def convert_param_value(param_value: str, param_name: str, - param_config: dict, func_name: str) -> Any: - # Handle null value for any type - if param_value.lower() == "null": - return None - - converted_value: Any - - if param_name not in param_config: - if param_config != {}: - logger.warning( - "Parsed parameter '%s' is not defined in the tool " - "parameters for tool '%s', directly returning the " - "string value.", param_name, func_name) - return param_value - - if (isinstance(param_config[param_name], dict) - and "type" in param_config[param_name]): - param_type = str( - param_config[param_name]["type"]).strip().lower() - else: - param_type = "string" - if param_type in [ - "string", "str", "text", "varchar", "char", "enum" - ]: - return param_value - elif (param_type.startswith("int") or param_type.startswith("uint") - or param_type.startswith("long") - or param_type.startswith("short") - or param_type.startswith("unsigned")): - try: - converted_value = int(param_value) - return converted_value - except ValueError: - logger.warning( - "Parsed value '%s' of parameter '%s' is not an " - "integer in tool '%s', degenerating to string.", - param_value, param_name, func_name) - return param_value - elif (param_type.startswith("num") - or param_type.startswith("float")): - try: - float_param_value = float(param_value) - converted_value = (float_param_value if float_param_value - - int(float_param_value) != 0 
else - int(float_param_value)) - return converted_value - except ValueError: - logger.warning( - "Parsed value '%s' of parameter '%s' is not a float " - "in tool '%s', degenerating to string.", param_value, - param_name, func_name) - return param_value - elif param_type in ["boolean", "bool", "binary"]: - param_value = param_value.lower() - if param_value not in ["true", "false"]: - logger.warning( - "Parsed value '%s' of parameter '%s' is not a " - "boolean (`true` of `false`) in tool '%s', " - "degenerating to false.", param_value, param_name, - func_name) - return param_value == "true" - else: - if param_type == "object" or param_type.startswith("dict"): - try: - converted_value = json.loads(param_value) - return converted_value - except json.JSONDecodeError: - logger.warning( - "Parsed value '%s' of parameter '%s' is not a " - "valid JSON object in tool '%s', will try other " - "methods to parse it.", param_value, param_name, - func_name) - try: - converted_value = eval(param_value) - return converted_value - except Exception: - logger.warning( - "Parsed value '%s' of parameter '%s' cannot be " - "converted via Python `eval()` in tool '%s', " - "degenerating to string.", param_value, param_name, - func_name) - return param_value - # Extract function name end_index = function_call_str.index(">") function_name = function_call_str[:end_index] - param_config = get_arguments_config(function_name) + param_config = self._get_arguments_config(function_name, tools) parameters = function_call_str[end_index + 1:] param_dict = {} - for match in self.tool_call_parameter_regex.findall(parameters): - match_text = match[0] if match[0] else match[1] + for match_text in self.tool_call_parameter_regex.findall(parameters): idx = match_text.index(">") param_name = match_text[:idx] param_value = str(match_text[idx + 1:]) @@ -236,7 +220,7 @@ class Qwen3CoderToolParser(ToolParser): if param_value.endswith("\n"): param_value = param_value[:-1] - param_dict[param_name] = 
convert_param_value( + param_dict[param_name] = self._convert_param_value( param_value, param_name, param_config, function_name) return ToolCall( type="function", @@ -289,8 +273,7 @@ class Qwen3CoderToolParser(ToolParser): for function_call_str in function_calls ] - # Populate prev_tool_call_arr for serving layer to set - # finish_reason + # Populate prev_tool_call_arr for serving layer to set finish_reason self.prev_tool_call_arr.clear() # Clear previous calls for tool_call in tool_calls: if tool_call: @@ -303,8 +286,8 @@ class Qwen3CoderToolParser(ToolParser): # Extract content before tool calls content_index = model_output.find(self.tool_call_start_token) - content_index = (content_index if content_index >= 0 else - model_output.find(self.tool_call_prefix)) + idx = model_output.find(self.tool_call_prefix) + content_index = content_index if content_index >= 0 else idx content = model_output[:content_index] # .rstrip() return ExtractedToolCallInformation( @@ -329,13 +312,16 @@ class Qwen3CoderToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - # If no delta text, return None unless it's an EOS token after tool - # calls + # Store request for type conversion + if not previous_text: + self._reset_streaming_state() + self.streaming_request = request + + # If no delta text, return None unless it's an EOS token after tools if not delta_text: # Check if this is an EOS token after all tool calls are complete - # We check for tool calls in the text even if is_tool_call_started - # is False because it might have been reset after processing all - # tools + # Check for tool calls in text even if is_tool_call_started + # is False (might have been reset after processing all tools) if (delta_token_ids and self.tool_call_end_token_id not in delta_token_ids): # Count complete tool calls @@ -344,24 +330,19 @@ class Qwen3CoderToolParser(ToolParser): # If we have completed tool calls and populated # 
prev_tool_call_arr - if (complete_calls > 0 and len(self.prev_tool_call_arr) > 0): + if complete_calls > 0 and len(self.prev_tool_call_arr) > 0: # Check if all tool calls are closed - open_calls = ( - current_text.count(self.tool_call_start_token) - - current_text.count(self.tool_call_end_token)) + open_calls = current_text.count( + self.tool_call_start_token) - current_text.count( + self.tool_call_end_token) if open_calls == 0: - # Return empty delta message to allow finish_reason - # processing + # Return empty delta for finish_reason processing return DeltaMessage(content="") elif not self.is_tool_call_started and current_text: # This is a regular content response that's now complete return DeltaMessage(content="") return None - # Check if this is the first call (reset state if needed) - if not previous_text: - self._reset_streaming_state() - # Update accumulated text self.accumulated_text = current_text @@ -376,11 +357,11 @@ class Qwen3CoderToolParser(ToolParser): self.param_count = 0 self.json_started = False self.json_closed = False + self.accumulated_params = {} # Check if there are more tool calls - tool_starts_count = current_text.count( - self.tool_call_start_token) - if self.current_tool_index >= tool_starts_count: + tool_starts = current_text.count(self.tool_call_start_token) + if self.current_tool_index >= tool_starts: # No more tool calls self.is_tool_call_started = False # Continue processing next tool @@ -417,20 +398,20 @@ class Qwen3CoderToolParser(ToolParser): # We're in a tool call, find the current tool call portion # Need to find the correct tool call based on current_tool_index - tool_starts: list[int] = [] + tool_start_positions: list[int] = [] idx = 0 while True: idx = current_text.find(self.tool_call_start_token, idx) if idx == -1: break - tool_starts.append(idx) + tool_start_positions.append(idx) idx += len(self.tool_call_start_token) - if self.current_tool_index >= len(tool_starts): + if self.current_tool_index >= 
len(tool_start_positions): # No more tool calls to process yet return None - tool_start_idx = tool_starts[self.current_tool_index] + tool_start_idx = tool_start_positions[self.current_tool_index] # Find where this tool call ends (or current position if not ended yet) tool_end_idx = current_text.find(self.tool_call_end_token, tool_start_idx) @@ -443,19 +424,19 @@ class Qwen3CoderToolParser(ToolParser): # Looking for function header if not self.header_sent: if self.tool_call_prefix in tool_text: - func_start = (tool_text.find(self.tool_call_prefix) + - len(self.tool_call_prefix)) + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix) func_end = tool_text.find(">", func_start) if func_end != -1: # Found complete function name self.current_function_name = tool_text[func_start:func_end] - self.current_tool_string_id = self._generate_tool_call_id() + self.current_tool_id = self._generate_tool_call_id() self.header_sent = True self.in_function = True - # IMPORTANT: Add to prev_tool_call_arr immediately when we - # detect a tool call. This ensures + # IMPORTANT: Add to prev_tool_call_arr immediately when + # we detect a tool call. This ensures # finish_reason="tool_calls" even if parsing isn't complete already_added = any( tool.get("name") == self.current_function_name @@ -471,7 +452,7 @@ class Qwen3CoderToolParser(ToolParser): return DeltaMessage(tool_calls=[ DeltaToolCall( index=self.current_tool_index, - id=self.current_tool_string_id, + id=self.current_tool_id, function=DeltaFunctionCall( name=self.current_function_name, arguments=""), type="function", @@ -501,10 +482,11 @@ class Qwen3CoderToolParser(ToolParser): # Close JSON self.json_closed = True - # Extract the complete tool call to update prev_tool_call_arr - # with final arguments. 
Find the function content - func_start = (tool_text.find(self.tool_call_prefix) + - len(self.tool_call_prefix)) + # Extract complete tool call to update + # prev_tool_call_arr with final arguments + # Find the function content + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix) func_content_end = tool_text.find(self.function_end_token, func_start) if func_content_end != -1: @@ -512,15 +494,17 @@ class Qwen3CoderToolParser(ToolParser): # Parse to get the complete arguments try: parsed_tool = self._parse_xml_function_call( - func_content, request.tools if request else None) + func_content, self.streaming_request.tools + if self.streaming_request else None) if parsed_tool: - # Update existing entry in prev_tool_call_arr with - # complete arguments + # Update existing entry in + # prev_tool_call_arr with complete args for i, tool in enumerate(self.prev_tool_call_arr): - if (tool.get("name") == - parsed_tool.function.name): - self.prev_tool_call_arr[i]["arguments"] = ( - parsed_tool.function.arguments) + if tool.get( + "name") == parsed_tool.function.name: + args = parsed_tool.function.arguments + self.prev_tool_call_arr[i][ + "arguments"] = args break except Exception: pass # Ignore parsing errors during streaming @@ -535,73 +519,110 @@ class Qwen3CoderToolParser(ToolParser): # Reset state for next tool self.in_function = False self.json_closed = True + self.accumulated_params = {} return result # Look for parameters - # Count how many complete parameters we have processed - complete_params = tool_text.count(self.parameter_end_token) + # Find all parameter starts + param_starts = [] + idx = 0 + while True: + idx = tool_text.find(self.parameter_prefix, idx) + if idx == -1: + break + param_starts.append(idx) + idx += len(self.parameter_prefix) # Check if we should start a new parameter - if not self.in_param and self.param_count < complete_params: - # Find the unprocessed parameter - # Count parameter starts - param_starts = [] - idx = 0 
- while True: - idx = tool_text.find(self.parameter_prefix, idx) - if idx == -1: - break - param_starts.append(idx) - idx += len(self.parameter_prefix) + if (not self.in_param and self.param_count < len(param_starts) + and len(param_starts) > self.param_count): + # Process the next parameter + param_idx = param_starts[self.param_count] + param_start = param_idx + len(self.parameter_prefix) + remaining = tool_text[param_start:] - if len(param_starts) > self.param_count: - # Process the next parameter - param_idx = param_starts[self.param_count] - param_start = param_idx + len(self.parameter_prefix) - remaining = tool_text[param_start:] + if ">" in remaining: + # We have the complete parameter name + name_end = remaining.find(">") + self.current_param_name = remaining[:name_end] - if ">" in remaining: - # We have the complete parameter name - name_end = remaining.find(">") - self.current_param_name = remaining[:name_end] + # Find the parameter value + value_start = param_start + name_end + 1 + value_text = tool_text[value_start:] + if value_text.startswith("\n"): + value_text = value_text[1:] - # Find the parameter value - value_start = param_start + name_end + 1 - value_text = tool_text[value_start:] - if value_text.startswith("\n"): - value_text = value_text[1:] + # Find where this parameter ends + param_end_idx = value_text.find(self.parameter_end_token) + if param_end_idx == -1: + # No closing tag, look for next parameter or + # function end + next_param_idx = value_text.find(self.parameter_prefix) + func_end_idx = value_text.find(self.function_end_token) - # Find where this parameter ends - param_end_idx = value_text.find( - self.parameter_end_token) - if param_end_idx != -1: - # Complete parameter found - param_value = value_text[:param_end_idx] - if param_value.endswith("\n"): - param_value = param_value[:-1] - - # Build complete JSON fragment for this parameter - if self.param_count == 0: - json_fragment = ( - '"' + self.current_param_name + '": "' + - 
json.dumps(param_value)[1:-1] + '"') + if next_param_idx != -1 and (func_end_idx == -1 + or next_param_idx + < func_end_idx): + param_end_idx = next_param_idx + elif func_end_idx != -1: + param_end_idx = func_end_idx + else: + # Neither found, check if tool call is complete + if self.tool_call_end_token in tool_text: + # Tool call is complete, so parameter + # must be complete too. Use all + # remaining text before function end + param_end_idx = len(value_text) else: - json_fragment = ( - ', "' + self.current_param_name + '": "' + - json.dumps(param_value)[1:-1] + '"') + # Still streaming, wait for more content + return None - self.param_count += 1 + if param_end_idx != -1: + # Complete parameter found + param_value = value_text[:param_end_idx] + if param_value.endswith("\n"): + param_value = param_value[:-1] - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall( - arguments=json_fragment), - ) - ]) + # Store raw value for later processing + self.accumulated_params[ + self.current_param_name] = param_value - # Continue parameter value + # Get parameter configuration for type conversion + param_config = self._get_arguments_config( + self.current_function_name or "", + self.streaming_request.tools + if self.streaming_request else None) + + # Convert param value to appropriate type + converted_value = self._convert_param_value( + param_value, self.current_param_name, param_config, + self.current_function_name or "") + + # Build JSON fragment based on the converted type + # Use json.dumps to properly serialize the value + serialized_value = json.dumps(converted_value, + ensure_ascii=False) + + if self.param_count == 0: + json_fragment = (f'"{self.current_param_name}": ' + f'{serialized_value}') + else: + json_fragment = (f', "{self.current_param_name}": ' + f'{serialized_value}') + + self.param_count += 1 + + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + 
function=DeltaFunctionCall( + arguments=json_fragment), + ) + ]) + + # Continue parameter value - Not used in the current implementation + # since we process complete parameters above if self.in_param: if self.parameter_end_token in delta_text: # End of parameter @@ -613,25 +634,42 @@ class Qwen3CoderToolParser(ToolParser): gt_idx = value_chunk.find(">") value_chunk = value_chunk[gt_idx + 1:] - if (not self.current_param_value - and value_chunk.startswith("\n")): + if not self.current_param_value and value_chunk.startswith( + "\n"): value_chunk = value_chunk[1:] - # Calculate incremental JSON + # Store complete value full_value = self.current_param_value + value_chunk - prev_escaped = (json.dumps(self.current_param_value)[1:-1] - if self.current_param_value else "") - full_escaped = json.dumps(full_value)[1:-1] - delta_escaped = full_escaped[len(prev_escaped):] + self.accumulated_params[ + self.current_param_name] = full_value + # Get parameter configuration for type conversion + param_config = self._get_arguments_config( + self.current_function_name or "", + self.streaming_request.tools + if self.streaming_request else None) + + # Convert the parameter value to the appropriate type + converted_value = self._convert_param_value( + full_value, self.current_param_name or "", + param_config, self.current_function_name or "") + + # Serialize the converted value + serialized_value = json.dumps(converted_value, + ensure_ascii=False) + + # Since we've been streaming the quoted version, + # we need to close it properly + # This is complex - for now just complete the value self.in_param = False self.current_param_value = "" + # Just close the current parameter string return DeltaMessage(tool_calls=[ DeltaToolCall( index=self.current_tool_index, function=DeltaFunctionCall( - arguments=delta_escaped + '"'), + arguments='"'), # Close the string quote ) ]) else: @@ -643,18 +681,18 @@ class Qwen3CoderToolParser(ToolParser): gt_idx = value_chunk.find(">") value_chunk = 
value_chunk[gt_idx + 1:] - if (not self.current_param_value - and value_chunk.startswith("\n")): + if not self.current_param_value and value_chunk.startswith( + "\n"): value_chunk = value_chunk[1:] if value_chunk: # Stream the escaped delta - prev_escaped = (json.dumps( - self.current_param_value)[1:-1] - if self.current_param_value else "") + prev_escaped = json.dumps( + self.current_param_value, ensure_ascii=False + )[1:-1] if self.current_param_value else "" self.current_param_value += value_chunk - full_escaped = json.dumps( - self.current_param_value)[1:-1] + full_escaped = json.dumps(self.current_param_value, + ensure_ascii=False)[1:-1] delta_escaped = full_escaped[len(prev_escaped):] if delta_escaped: @@ -666,4 +704,4 @@ class Qwen3CoderToolParser(ToolParser): ) ]) - return None + return None \ No newline at end of file diff --git a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py new file mode 100644 index 0000000000000..95458f07ff2a2 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py @@ -0,0 +1,679 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from qwen3coder xml parser, All rights reserved. 
+# ruff: noqa: E501 + +import ast +import json +import uuid +from collections.abc import Sequence +from typing import Any, Optional, Union + +import regex as re + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("seed_oss") +class SeedOssToolParser(ToolParser): + TOOL_CALL_START = "<seed:tool_call>" + TOOL_CALL_END = "</seed:tool_call>" + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + # --- streaming state --- + self._reset_streaming_state() + self.prev_tool_call_arr: list[dict] = [] + + self.tool_call_start_token: str = self.TOOL_CALL_START + self.tool_call_end_token: str = self.TOOL_CALL_END + # Sentinel tokens for streaming mode + self.tool_call_prefix: str = "<function=" + self.function_end_token: str = "</function>" + self.parameter_prefix: str = "<parameter=" + self.parameter_end_token: str = "</parameter>" + self.think_start_token: str = "<seed:think>" + self.think_end_token: str = "</seed:think>" + self.is_tool_call_started: bool = False + self.is_thinking_end: bool = False + self.failed_count: int = 0 + self._reset_streaming_state() + + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + self.think_end_token_id = self.vocab.get(self.think_end_token) + + if (self.tool_call_start_token_id is None + or self.tool_call_end_token_id is None): + raise RuntimeError( + "Seed_Oss XML parser: tokenizer did not include " + "<seed:tool_call> or its closing tag.") + + tool_start_re = 
re.escape(self.tool_call_start_token) + tool_end_re = re.escape(self.tool_call_end_token) + + self.tool_call_complete_regex = re.compile( + rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL) + self.tool_call_regex = re.compile( + rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$", + re.DOTALL) + + self.tool_call_function_regex = re.compile( + r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL) + self.tool_call_parameter_regex = re.compile( + r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL) + + logger.info("vLLM Seed-Oss XML tool parser loaded (%s).", + self.__class__.__name__) + + def _generate_tool_call_id(self) -> str: + """Generate a unique tool call ID.""" + return f"call_{uuid.uuid4().hex[:24]}" + + def _reset_streaming_state(self): + """Reset all streaming state.""" + self.current_tool_index = 0 + self.is_tool_call_started = False + self.header_sent = False + self.current_tool_id = -1 + self.current_function_name = None + self.current_param_name = None + self.current_param_value = "" + self.param_count = 0 + self.in_param = False + self.in_function = False + self.accumulated_text = "" + self.json_started = False + self.json_closed = False + + def _parse_xml_function_call( + self, function_call_str: str, + tools: Optional[list[ChatCompletionToolsParam]] + ) -> Optional[ToolCall]: + + def get_arguments_config(func_name: str) -> dict: + if tools is None: + return {} + for config in tools: + if not hasattr(config, "type") or not ( + hasattr(config, "function") + and hasattr(config.function, "name")): + continue + if (config.type == "function" + and config.function.name == func_name): + if not hasattr(config.function, "parameters"): + return {} + params = config.function.parameters + if isinstance(params, dict) and "properties" in params: + return params["properties"] + elif isinstance(params, dict): + return params + else: + return {} + logger.warning("Tool '%s' is not defined in the tools list.", + func_name) + return {} + + def 
convert_param_value(param_value: str, param_name: str, + param_config: dict, func_name: str) -> Any: + # Handle null value for any type + if param_value.lower() == "null": + return None + + if param_name not in param_config: + if param_config != {}: + logger.warning( + "Parsed parameter '%s' is not defined in " + "the tool parameters for tool '%s', " + "directly returning the string value.", param_name, + func_name) + return param_value + + if (isinstance(param_config[param_name], dict) + and "type" in param_config[param_name]): + param_type = str( + param_config[param_name]["type"]).strip().lower() + else: + param_type = "string" + if param_type in [ + "string", "str", "text", "varchar", "char", "enum" + ]: + return param_value + elif (param_type.startswith("int") or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned")): + try: + param_value = int(param_value) # type: ignore + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not an integer in tool " + "'%s', degenerating to string.", param_value, + param_name, func_name) + return param_value + elif param_type.startswith("num") or param_type.startswith( + "float"): + try: + float_param_value = float(param_value) + param_value = float_param_value if float_param_value - int( + float_param_value) != 0 else int( + float_param_value) # type: ignore + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not a float in tool " + "'%s', degenerating to string.", param_value, + param_name, func_name) + return param_value + elif param_type in ["boolean", "bool", "binary"]: + param_value = param_value.lower() + if param_value not in ["true", "false"]: + logger.warning( + "Parsed value '%s' of parameter '%s' is not a boolean " + "(`true` of `false`) in tool '%s', degenerating to false.", + param_value, param_name, func_name) + return param_value == "true" + 
else: + if param_type == "object" or param_type.startswith("dict"): + try: + param_value = json.loads(param_value) + return param_value + except (ValueError, TypeError, json.JSONDecodeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not a valid JSON " + "object in tool '%s', will try other methods to parse it.", + param_value, param_name, func_name) + try: + param_value = ast.literal_eval(param_value) + except (ValueError, SyntaxError): + logger.warning( + "Parsed value '%s' of parameter '%s' cannot be converted via " + "Python `ast.literal_eval()` in tool '%s', degenerating to string.", + param_value, param_name, func_name) + return param_value + + # Extract function name + end_index = function_call_str.index(">") + function_name = function_call_str[:end_index] + param_config = get_arguments_config(function_name) + parameters = function_call_str[end_index + 1:] + param_dict = {} + for match in self.tool_call_parameter_regex.findall(parameters): + match_text = match[0] if match[0] else match[1] + idx = match_text.index(">") + param_name = match_text[:idx] + param_value = str(match_text[idx + 1:]) + # Remove prefix and trailing \n + if param_value.startswith("\n"): + param_value = param_value[1:] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + param_dict[param_name] = convert_param_value( + param_value, param_name, param_config, function_name) + return ToolCall( + type="function", + function=FunctionCall(name=function_name, + arguments=json.dumps(param_dict, + ensure_ascii=False)), + ) + + def _get_function_calls(self, model_output: str) -> list[str]: + # Find all tool calls + matched_ranges = self.tool_call_regex.findall(model_output) + raw_tool_calls = [ + match[0] if match[0] else match[1] for match in matched_ranges + ] + + # Back-off strategy if no tool_call tags found + if len(raw_tool_calls) == 0: + raw_tool_calls = [model_output] + + raw_function_calls = [] + for tool_call in raw_tool_calls: + 
raw_function_calls.extend( + self.tool_call_function_regex.findall(tool_call)) + + function_calls = [ + match[0] if match[0] else match[1] for match in raw_function_calls + ] + return function_calls + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + # Quick check to avoid unnecessary processing + if self.tool_call_prefix not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + # Check if both think start and end tokens are present + if (self.think_start_token in model_output + and self.think_end_token in model_output): + # Find the position of think end token + think_end_index = model_output.find(self.think_end_token) + len( + self.think_end_token) + # Extract content after think end token + result_content = model_output[think_end_index:] + thinking_content = model_output[:think_end_index] + else: + thinking_content = "" + result_content = model_output + + try: + function_calls = self._get_function_calls(result_content) + if len(function_calls) == 0: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + tool_calls = [ + self._parse_xml_function_call(function_call_str, request.tools) + for function_call_str in function_calls + ] + + # Populate prev_tool_call_arr for serving layer to set finish_reason + self.prev_tool_call_arr.clear() # Clear previous calls + for tool_call in tool_calls: + if tool_call: + self.prev_tool_call_arr.append({ + "name": + tool_call.function.name, + "arguments": + tool_call.function.arguments, + }) + + # Extract content before tool calls + tool_call_start_index = result_content.find( + self.tool_call_start_token) + tool_call_start_index = ( + tool_call_start_index if tool_call_start_index >= 0 else + result_content.find(self.tool_call_prefix)) + content = thinking_content + result_content[:tool_call_start_index] + + return 
ExtractedToolCallInformation( + tools_called=(len(tool_calls) > 0), + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception: + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + # If no delta text, return None unless + # it's an EOS token after tool calls + if not delta_text: + # Check if this is an EOS token after all tool calls are complete + # We check for tool calls in the text even if is_tool_call_started + # is False because it might have been reset after processing all tools + if (delta_token_ids + and self.tool_call_end_token_id not in delta_token_ids): + # Count complete tool calls + complete_calls = len( + self.tool_call_complete_regex.findall(current_text)) + + # If we have completed tool calls and populated prev_tool_call_arr + if complete_calls > 0 and len(self.prev_tool_call_arr) > 0: + # Check if all tool calls are closed + open_calls = current_text.count( + self.tool_call_start_token) - current_text.count( + self.tool_call_end_token) + if open_calls == 0: + # Return empty delta message to allow finish_reason processing + return DeltaMessage(content="") + elif not self.is_tool_call_started and current_text: + # This is a regular content response that's now complete + return DeltaMessage(content="") + return None + + # Check if this is the first call (reset state if needed) + if not previous_text: + self._reset_streaming_state() + + # Update accumulated text + self.accumulated_text = current_text + + # Check if we need to advance to next tool + if self.json_closed and not self.in_function: + # Check if this tool call has ended 
+ tool_ends = current_text.count(self.tool_call_end_token) + if tool_ends > self.current_tool_index: + # This tool has ended, advance to next + self.current_tool_index += 1 + self.header_sent = False + self.param_count = 0 + self.json_started = False + self.json_closed = False + + # Check if there are more tool calls + if self.current_tool_index >= current_text.count( + self.tool_call_start_token): + # No more tool calls + self.is_tool_call_started = False + # Continue processing next tool + return None + + # Check if end thinking + if (not self.is_thinking_end + and (self.think_end_token_id in delta_token_ids + or self.think_end_token in delta_text)): + self.is_thinking_end = True + + # If thinking hasn't ended yet, don't process any tool calls + if not self.is_thinking_end: + return DeltaMessage(content=delta_text) + + # Handle normal content before tool calls + if not self.is_tool_call_started: + # Check if tool call is starting + if (self.tool_call_start_token_id in delta_token_ids + or self.tool_call_start_token in delta_text): + self.is_tool_call_started = True + # Return any content before the tool call + if self.tool_call_start_token in delta_text: + content_before = delta_text[:delta_text.index( + self.tool_call_start_token)] + if content_before: + return DeltaMessage(content=content_before) + return None + else: + # Check if we're between tool calls - skip whitespace + if (current_text.rstrip().endswith(self.tool_call_end_token) + and delta_text.strip() == ""): + # We just ended a tool call, skip whitespace + return None + # Normal content, no tool call + return DeltaMessage(content=delta_text) + + # Check if we're between tool calls (waiting for next one) + # Count tool calls we've seen vs processed + tool_starts_count = current_text.count(self.tool_call_start_token) + if self.current_tool_index >= tool_starts_count: + # We're past all tool calls, shouldn't be here + return None + + # We're in a tool call, find the current tool call portion + # Need to 
find the correct tool call based on current_tool_index + # Only process tool calls after think_end_token + think_end_index = current_text.find(self.think_end_token) + len( + self.think_end_token + ) if self.think_end_token in current_text else 0 + tool_starts: list[int] = [] + idx = think_end_index + while True: + idx = current_text.find(self.tool_call_start_token, idx) + if idx == -1: + break + tool_starts.append(idx) + idx += len(self.tool_call_start_token) + + if self.current_tool_index >= len(tool_starts): + # No more tool calls to process yet + return None + + tool_start_idx = tool_starts[self.current_tool_index] + # Find where this tool call ends (or current position if not ended yet) + tool_end_idx = current_text.find(self.tool_call_end_token, + tool_start_idx) + if tool_end_idx == -1: + tool_text = current_text[tool_start_idx:] + else: + tool_text = current_text[tool_start_idx:tool_end_idx + + len(self.tool_call_end_token)] + + # Looking for function header + if not self.header_sent: + if self.tool_call_prefix in tool_text: + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix) + func_end = tool_text.find(">", func_start) + + if func_end != -1: + # Found complete function name + self.current_function_name = tool_text[func_start:func_end] + self.current_tool_id = self._generate_tool_call_id( + ) # type: ignore + self.header_sent = True + self.in_function = True + + # IMPORTANT: Add to prev_tool_call_arr immediately when we detect a tool call + # This ensures finish_reason="tool_calls" even if parsing isn't complete + already_added = any( + tool.get("name") == self.current_function_name + for tool in self.prev_tool_call_arr) + if not already_added: + self.prev_tool_call_arr.append({ + "name": self.current_function_name, + "arguments": + "{}", # Placeholder, will be updated later + }) + + # Send header with function info + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + id=self.current_tool_id, 
+ function=DeltaFunctionCall( + name=self.current_function_name, arguments=""), + type="function", + ) + ]) + return None + + # We've sent header, now handle function body + if self.in_function: + # Send opening brace if not sent yet + if (not self.json_started + and self.parameter_prefix not in delta_text): + self.json_started = True + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="{"), + ) + ]) + + # Make sure json_started is set if we're processing parameters + if not self.json_started: + self.json_started = True + + # Check for function end in accumulated text + if not self.json_closed and self.function_end_token in tool_text: + # Close JSON + self.json_closed = True + + # Extract the complete tool call to update prev_tool_call_arr with final arguments + # Find the function content + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix) + func_content_end = tool_text.find(self.function_end_token, + func_start) + if func_content_end != -1: + func_content = tool_text[func_start:func_content_end] + # Parse to get the complete arguments + try: + parsed_tool = self._parse_xml_function_call( + func_content, request.tools if request else None) + if parsed_tool: + # Update existing entry in prev_tool_call_arr with complete arguments + for i, tool in enumerate(self.prev_tool_call_arr): + if tool.get( + "name") == parsed_tool.function.name: + self.prev_tool_call_arr[i]["arguments"] = ( + parsed_tool.function.arguments) + break + except Exception: + logger.warning( + "Failed to parse tool arguments during streaming.", + exc_info=True) + + result = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="}"), + ) + ]) + + # Reset state for next tool + self.in_function = False + self.json_closed = True + + return result + + # Look for parameters + # Count how many complete parameters we have processed + 
complete_params = tool_text.count(self.parameter_end_token) + + # Check if we should start a new parameter + if not self.in_param and self.param_count < complete_params: + # Find the unprocessed parameter + # Count parameter starts + param_starts = [] + idx = 0 + while True: + idx = tool_text.find(self.parameter_prefix, idx) + if idx == -1: + break + param_starts.append(idx) + idx += len(self.parameter_prefix) + + if len(param_starts) > self.param_count: + # Process the next parameter + param_idx = param_starts[self.param_count] + param_start = param_idx + len(self.parameter_prefix) + remaining = tool_text[param_start:] + + if ">" in remaining: + # We have the complete parameter name + name_end = remaining.find(">") + self.current_param_name = remaining[:name_end] + + # Find the parameter value + value_start = param_start + name_end + 1 + value_text = tool_text[value_start:] + if value_text.startswith("\n"): + value_text = value_text[1:] + + # Find where this parameter ends + param_end_idx = value_text.find( + self.parameter_end_token) + if param_end_idx != -1: + # Complete parameter found + param_value = value_text[:param_end_idx] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + # Build complete JSON fragment for this parameter + if self.param_count == 0: + json_fragment = ( + '"' + self.current_param_name + '": "' + + json.dumps(param_value)[1:-1] + '"') + else: + json_fragment = ( + ', "' + self.current_param_name + '": "' + + json.dumps(param_value)[1:-1] + '"') + + self.param_count += 1 + + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=json_fragment), + ) + ]) + + # Continue parameter value + if self.in_param: + if self.parameter_end_token in delta_text: + # End of parameter + end_idx = delta_text.find(self.parameter_end_token) + value_chunk = delta_text[:end_idx] + + # Skip past > if at start + if not self.current_param_value and ">" in value_chunk: + gt_idx = 
value_chunk.find(">") + value_chunk = value_chunk[gt_idx + 1:] + + if not self.current_param_value and value_chunk.startswith( + "\n"): + value_chunk = value_chunk[1:] + + # Calculate incremental JSON + full_value = self.current_param_value + value_chunk + prev_escaped = (json.dumps(self.current_param_value)[1:-1] + if self.current_param_value else "") + full_escaped = json.dumps(full_value)[1:-1] + delta_escaped = full_escaped[len(prev_escaped):] + + self.in_param = False + self.current_param_value = "" + + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=delta_escaped + '"'), + ) + ]) + else: + # Continue accumulating value + value_chunk = delta_text + + # Handle first chunk after param name + if not self.current_param_value and ">" in value_chunk: + gt_idx = value_chunk.find(">") + value_chunk = value_chunk[gt_idx + 1:] + + if not self.current_param_value and value_chunk.startswith( + "\n"): + value_chunk = value_chunk[1:] + + if value_chunk: + # Stream the escaped delta + prev_escaped = (json.dumps( + self.current_param_value)[1:-1] + if self.current_param_value else "") + self.current_param_value += value_chunk + full_escaped = json.dumps( + self.current_param_value)[1:-1] + delta_escaped = full_escaped[len(prev_escaped):] + + if delta_escaped: + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=delta_escaped), + ) + ]) + + return None diff --git a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py index 321718b1c950b..87cd413b37200 100644 --- a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py @@ -7,7 +7,7 @@ from typing import Any, Optional, Union import regex as re -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import 
make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -226,7 +226,7 @@ class xLAMToolParser(ToolParser): function_name = name_match.group(1) # The test expects us to send just the name first - tool_id = random_tool_call_id() + tool_id = make_tool_call_id() delta = DeltaMessage(tool_calls=[ DeltaToolCall( index=0, diff --git a/vllm/envs.py b/vllm/envs.py index 82084d1fc5ae1..35735b552575b 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib +import json import os import sys import tempfile @@ -42,7 +43,6 @@ if TYPE_CHECKING: VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None - VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: Optional[int] = 0 VLLM_CPU_OMP_THREADS_BIND: str = "" @@ -131,7 +131,9 @@ if TYPE_CHECKING: VLLM_TPU_USING_PATHWAYS: bool = False VLLM_USE_DEEP_GEMM: bool = False VLLM_USE_DEEP_GEMM_E8M0: bool = True + VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False + VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False VLLM_FLASHINFER_MOE_BACKEND: str = "throughput" @@ -159,8 +161,10 @@ if TYPE_CHECKING: VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False VLLM_ENABLE_RESPONSES_API_STORE: bool = False VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None + VLLM_HAS_FLASHINFER_CUBIN: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False + VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None @@ -465,11 +469,6 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"])) if 
"VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None, - # If set, vllm will force flashinfer to use tensor cores; - # otherwise will use heuristic based on model architecture. - "VLLM_FLASHINFER_FORCE_TENSOR_CORES": - lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))), - # Pipeline stage partition strategy "VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), @@ -667,11 +666,14 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv("VLLM_LORA_RESOLVER_CACHE_DIR", None), - # Enables torch profiler if set. Path to the directory where torch profiler - # traces are saved. Note that it must be an absolute path. + # Enables torch profiler if set. + # Both AsyncLLM's CPU traces as well as workers' + # traces (CPU & GPU) will be saved under this directory. + # Note that it must be an absolute path. "VLLM_TORCH_PROFILER_DIR": lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os - .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))), + .path.abspath(os.path.expanduser(os.getenv( + "VLLM_TORCH_PROFILER_DIR", ".")))), # Enable torch profiler to record shapes if set # VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will @@ -953,9 +955,12 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs. - # E8M0 is faster on B200 but may reduce accuracy. "VLLM_USE_DEEP_GEMM_E8M0": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))), + # TODO(wentao): unify the two E8M0 flags after verifying the correctness. + # Whether to use E8M0 scaling when DeepGEMM is used on Hopper GPUs. + "VLLM_USE_DEEP_GEMM_E8M0_HOPPER": + lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "0"))), # DeepGemm JITs the kernels on-demand. 
The warmup attempts to make DeepGemm # JIT all the required kernels before model execution so there is no # JIT'ing in the hot-path. However, this warmup increases the engine @@ -964,6 +969,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_SKIP_DEEP_GEMM_WARMUP": lambda: bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))), + # Whether to use fused grouped_topk used for MoE expert selection. + "VLLM_USE_FUSED_MOE_GROUPED_TOPK": + lambda: bool(int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))), + # Allow use of FlashInfer MoE kernels for fused moe ops. "VLLM_USE_FLASHINFER_MOE_FP8": lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))), @@ -1042,6 +1051,16 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE": lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")), + # Specifies the thresholds of the communicated tensor sizes under which + # vllm should use flashinfer fused allreduce. The variable should be a + # JSON with the following format: + # { <world size>: <max size in mb> } + # Unspecified world sizes will fallback to + # { 2: 64, 4: 1, <everything else>: 0.5 } + "VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB": + lambda: json.loads(os.getenv( + "VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB", "{}")), + # MoE routing strategy selector. # See `RoutingSimulator.get_available_strategies()` # for available # strategies. @@ -1108,6 +1127,11 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_TRTLLM_ATTENTION": lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None), + # If set, it means we pre-downloaded cubin files and flashinfer will + # read the cubin files directly. + "VLLM_HAS_FLASHINFER_CUBIN": + lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False), + # If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer. # Otherwise, uses the first available of: flashinfer cutlass GEMM, # vllm cutlass GEMM, marlin GEMM. 
@@ -1153,6 +1177,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_ENABLE_RESPONSES_API_STORE": lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))), + # Whether to use pytorch symmetric memory for allreduce + "VLLM_ALLREDUCE_USE_SYMM_MEM": + lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))), + # Allows vllm to find tuned config under customized folder "VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), @@ -1199,14 +1227,6 @@ def compute_hash() -> str: affect the choice of different kernels or attention backends should also be included in the factors list. """ - factors: list[Any] = [] - - # summarize environment variables - def factorize(name: str): - if __getattr__(name): - factors.append(__getattr__(name)) - else: - factors.append("None") # The values of envs may affects the computation graph. # TODO(DefTruth): hash all environment variables? @@ -1221,11 +1241,47 @@ def compute_hash() -> str: "VLLM_DP_SIZE", "VLLM_USE_STANDALONE_COMPILE", "VLLM_FUSED_MOE_CHUNK_SIZE", + "VLLM_FLASHINFER_MOE_BACKEND", + "VLLM_V1_USE_PREFILL_DECODE_ATTENTION", + "VLLM_USE_AITER_UNIFIED_ATTENTION", + "VLLM_ATTENTION_BACKEND", + "VLLM_USE_FLASHINFER_SAMPLER", + "VLLM_DISABLED_KERNELS", + "VLLM_USE_DEEP_GEMM", + "VLLM_USE_DEEP_GEMM_E8M0", + "VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "VLLM_USE_TRTLLM_FP4_GEMM", + "VLLM_USE_FUSED_MOE_GROUPED_TOPK", + "VLLM_USE_FLASHINFER_MOE_FP8", + "VLLM_USE_FLASHINFER_MOE_FP4", + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", + "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", + "VLLM_USE_CUDNN_PREFILL", + "VLLM_USE_TRTLLM_ATTENTION", + "VLLM_ROCM_USE_AITER", + "VLLM_ROCM_USE_AITER_PAGED_ATTN", + "VLLM_ROCM_USE_AITER_LINEAR", + "VLLM_ROCM_USE_AITER_MOE", + "VLLM_ROCM_USE_AITER_RMSNORM", + "VLLM_ROCM_USE_AITER_MLA", + "VLLM_ROCM_USE_AITER_MHA", + "VLLM_ROCM_USE_SKINNY_GEMM", + "VLLM_ROCM_FP8_PADDING", + "VLLM_ROCM_MOE_PADDING", + "VLLM_ROCM_CUSTOM_PAGED_ATTN", + "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", + 
"VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", + "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", ] for key in environment_variables_to_hash: - if key in environment_variables: - factorize(key) + # if this goes out of sync with environment_variables, + # it's not a user error, it's a bug + assert key in environment_variables, \ + "Please update environment_variables_to_hash in envs.py" + + factors = [ + environment_variables[key]() for key in environment_variables_to_hash + ] hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() diff --git a/vllm/executor/msgspec_utils.py b/vllm/executor/msgspec_utils.py index 852c8f5cffa0c..4ce6d8dfad2cc 100644 --- a/vllm/executor/msgspec_utils.py +++ b/vllm/executor/msgspec_utils.py @@ -4,11 +4,12 @@ from array import array from typing import Any, Type +from vllm.multimodal.inputs import MultiModalKwargs from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE def encode_hook(obj: Any) -> Any: - """Custom msgspec enc hook that supports array types. + """Custom msgspec enc hook that supports array types and MultiModalKwargs. See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder """ @@ -17,10 +18,12 @@ def encode_hook(obj: Any) -> Any: f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. " f"Given array has a type code of {obj.typecode}.") return obj.tobytes() + if isinstance(obj, MultiModalKwargs): + return dict(obj) def decode_hook(type: Type, obj: Any) -> Any: - """Custom msgspec dec hook that supports array types. + """Custom msgspec dec hook that supports array types and MultiModalKwargs. 
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder """ @@ -28,3 +31,5 @@ def decode_hook(type: Type, obj: Any) -> Any: deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE) deserialized.frombytes(obj) return deserialized + if type is MultiModalKwargs: + return MultiModalKwargs(obj) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index de5dc0876651a..f0d0cab3df3d9 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -11,6 +11,7 @@ from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalInputs) from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -32,12 +33,14 @@ class InputPreprocessor: model_config: ModelConfig, tokenizer: Optional[TokenizerGroup], mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None, ) -> None: super().__init__() self.model_config = model_config self.tokenizer = tokenizer self.mm_registry = mm_registry + self.mm_processor_cache = mm_processor_cache def get_tokenizer_group(self) -> TokenizerGroup: if self.tokenizer is None: @@ -254,7 +257,6 @@ class InputPreprocessor: mm_processor_kwargs: Optional[Mapping[str, object]], tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> MultiModalInputs: """ Apply the model's multi-modal processor to a multi-modal prompt, @@ -262,8 +264,11 @@ class InputPreprocessor: """ tokenizer = self._get_mm_tokenizer(lora_request) - mm_processor = self.mm_registry.create_processor(self.model_config, - tokenizer=tokenizer) + mm_processor = self.mm_registry.create_processor( + self.model_config, + tokenizer=tokenizer, + 
cache=self.mm_processor_cache, + ) if mm_processor_kwargs is None: mm_processor_kwargs = {} @@ -271,8 +276,7 @@ class InputPreprocessor: return mm_processor.apply(prompt, mm_data, hf_processor_mm_kwargs=mm_processor_kwargs, - tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes) + tokenization_kwargs=tokenization_kwargs) async def _process_multimodal_async( self, @@ -281,7 +285,6 @@ class InputPreprocessor: mm_processor_kwargs: Optional[Mapping[str, object]], tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> MultiModalInputs: """ Async version of @@ -289,16 +292,19 @@ class InputPreprocessor: """ tokenizer = await self._get_mm_tokenizer_async(lora_request) - mm_processor = self.mm_registry.create_processor(self.model_config, - tokenizer=tokenizer) + mm_processor = self.mm_registry.create_processor( + self.model_config, + tokenizer=tokenizer, + cache=self.mm_processor_cache, + ) + if mm_processor_kwargs is None: mm_processor_kwargs = {} return mm_processor.apply(prompt, mm_data, hf_processor_mm_kwargs=mm_processor_kwargs, - tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes) + tokenization_kwargs=tokenization_kwargs) def _process_embeds( self, @@ -335,7 +341,6 @@ class InputPreprocessor: parsed_content: TokensPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> Union[TokenInputs, MultiModalInputs]: prompt_token_ids = parsed_content["prompt_token_ids"] token_type_ids = parsed_content.get("token_type_ids") @@ -348,7 +353,6 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) else: inputs = token_inputs( @@ -366,7 +370,6 @@ class InputPreprocessor: parsed_content: TokensPrompt, tokenization_kwargs: Optional[dict[str, 
Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> Union[TokenInputs, MultiModalInputs]: prompt_token_ids = parsed_content["prompt_token_ids"] token_type_ids = parsed_content.get("token_type_ids") @@ -379,7 +382,6 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) else: inputs = token_inputs( @@ -397,7 +399,6 @@ class InputPreprocessor: parsed_content: TextPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> Union[TokenInputs, MultiModalInputs]: prompt_text = parsed_content["prompt"] @@ -409,7 +410,6 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) else: prompt_token_ids = self._tokenize_prompt( @@ -432,7 +432,6 @@ class InputPreprocessor: parsed_content: TextPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> Union[TokenInputs, MultiModalInputs]: prompt_text = parsed_content["prompt"] @@ -444,7 +443,6 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) else: prompt_token_ids = await self._tokenize_prompt_async( @@ -467,7 +465,6 @@ class InputPreprocessor: prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> SingletonInputs: """ Extract the singleton inputs from a prompt. 
@@ -476,7 +473,6 @@ class InputPreprocessor: * prompt: single encoder or decoder input prompt * lora_request: this is only valid for decoder prompts - * return_mm_hashes: whether to return multimodal hashes Returns: @@ -490,21 +486,18 @@ class InputPreprocessor: return self._process_tokens( parsed["content"], lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) if parsed["type"] == "text": return self._process_text( parsed["content"], tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) if parsed["type"] == "str": return self._process_text( TextPrompt(prompt=parsed["content"]), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) assert_never(parsed) @@ -514,7 +507,6 @@ class InputPreprocessor: prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> SingletonInputs: """ Async version of @@ -528,21 +520,18 @@ class InputPreprocessor: return await self._process_tokens_async( parsed["content"], lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) if parsed["type"] == "text": return await self._process_text_async( parsed["content"], tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) if parsed["type"] == "str": return await self._process_text_async( TextPrompt(prompt=parsed["content"]), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) assert_never(parsed) @@ -785,7 +774,6 @@ class InputPreprocessor: prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> DecoderOnlyInputs: """ For decoder-only models: @@ -796,7 +784,6 @@ class InputPreprocessor: * prompt: input prompt * lora_request - * return_mm_hashes Returns: @@ 
-807,7 +794,6 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) return self._build_decoder_only_llm_inputs(prompt_comps) @@ -817,7 +803,6 @@ class InputPreprocessor: prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> DecoderOnlyInputs: """ Async version of @@ -827,7 +812,6 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) return self._build_decoder_only_llm_inputs(prompt_comps) @@ -837,17 +821,15 @@ class InputPreprocessor: prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> ProcessorInputs: """Preprocess the input prompt.""" if self.model_config.is_encoder_decoder: - assert not return_mm_hashes, ( - "Multimodal hashes for encoder-decoder models should not be ", - "returned until they are supported on vLLM V1.") # Encoder-decoder model requires special mapping of - # input prompts to encoder & decoder + # input prompts to encoder & decoder. return self._process_encoder_decoder_prompt( - prompt, tokenization_kwargs) + prompt, + tokenization_kwargs, + ) if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " @@ -858,7 +840,6 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) async def preprocess_async( @@ -866,19 +847,18 @@ class InputPreprocessor: prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> ProcessorInputs: """ Async version of [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess]. 
""" if self.model_config.is_encoder_decoder: - assert not return_mm_hashes, ( - "Multimodal hashes for encoder-decoder models should not be ", - "returned until they are supported on vLLM V1.") # Encoder-decoder model requires special mapping of - # input prompts to encoder & decoder - return await self._process_encoder_decoder_prompt_async(prompt) + # input prompts to encoder & decoder. + return await self._process_encoder_decoder_prompt_async( + prompt, + tokenization_kwargs, + ) if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " @@ -889,5 +869,8 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) + + def clear_cache(self) -> None: + if self.mm_processor_cache is not None: + self.mm_processor_cache.clear_cache() diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index dc3236508348f..f0b392e9767ae 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -223,23 +223,29 @@ class InputRegistry: The model is identified by ``model_config``. 
""" # Avoid circular import + from vllm.multimodal.cache import processor_only_cache_from_config from vllm.sequence import SequenceData if not model_config.is_multimodal_model: seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) return DummyData(seq_data=seq_data) + cache = processor_only_cache_from_config(model_config, mm_registry) + # Encoder dummy data does not contain multi-modal data if is_encoder_data: - enc_data = mm_registry.get_encoder_dummy_data( - model_config, seq_len) + enc_data = mm_registry.get_encoder_dummy_data(model_config, + seq_len, + cache=cache) seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids) return DummyData(seq_data=seq_data) - dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len) + dec_data = mm_registry.get_decoder_dummy_data(model_config, + seq_len, + cache=cache) return DummyData( seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids), - multi_modal_data=dec_data.multi_modal_data, + multi_modal_data=dec_data.multi_modal_data.get_data(), multi_modal_placeholders=dec_data.multi_modal_placeholders, ) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index de5933d6d41e5..24a05d310d108 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -48,9 +48,6 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device: # GPTQ/AWQ elif hasattr(base_layer, "qweight"): return base_layer.qweight.device - # marlin - elif hasattr(base_layer, "B"): - return base_layer.B.device # HQQ marlin elif hasattr(base_layer, "W_q"): return base_layer.W_q.device diff --git a/vllm/lora/models.py b/vllm/lora/models.py index e6b19d4748f44..3072047a2606c 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -207,6 +207,7 @@ class LoRAModel(AdapterModel): """ lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") + lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt") new_embeddings_tensor_path = os.path.join( lora_dir, 
"new_embeddings.safetensors") new_embeddings_bin_file_path = os.path.join(lora_dir, @@ -255,9 +256,10 @@ class LoRAModel(AdapterModel): check_unexpected_modules(f) for module in f.keys(): # noqa tensors[module] = f.get_tensor(module) - elif os.path.isfile(lora_bin_file_path): - # When a bin file is provided, we rely on config to find unexpected - # modules. + elif os.path.isfile(lora_bin_file_path) or os.path.isfile( + lora_pt_file_path): + # When a bin/pt file is provided, we rely on config to find + # unexpected modules. unexpected_modules = [] target_modules = peft_helper.target_modules if not isinstance(target_modules, list): @@ -279,7 +281,10 @@ class LoRAModel(AdapterModel): f" target modules in {expected_lora_modules}" f" but received {unexpected_modules}." f" Please verify that the loaded LoRA module is correct") - tensors = torch.load(lora_bin_file_path, + lora_file_path = (lora_bin_file_path + if os.path.isfile(lora_bin_file_path) else + lora_pt_file_path) + tensors = torch.load(lora_file_path, map_location=device, weights_only=True) else: diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 7ce44174ead6d..f3248589abc47 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -239,6 +239,35 @@ class GeluAndMul(CustomOp): return f'approximate={repr(self.approximate)}' +@CustomOp.register("swigluoai_and_mul") +class SwigluOAIAndMul(CustomOp): + # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110 + def __init__(self, alpha: float = 1.702, limit: float = 7.0): + super().__init__() + self.alpha = alpha + self.limit = limit + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + + gate, up = x[..., ::2], x[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * 
torch.sigmoid(gate * self.alpha) + gated_output = (up + 1) * glu + return gated_output + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit) + return out + + def extra_repr(self) -> str: + return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}" + + @CustomOp.register("gelu_new") class NewGELU(CustomOp): @@ -330,6 +359,7 @@ class ReLUSquaredActivation(CustomOp): return torch.square(F.relu(x)) def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + #TODO : implement cuda kenrels return self.forward_native(x) @@ -392,12 +422,23 @@ _ACTIVATION_REGISTRY = LazyDict({ lambda: nn.SiLU(), "quick_gelu": lambda: QuickGELU(), + "tanh": + lambda: nn.Tanh(), + "sigmoid": + lambda: nn.Sigmoid(), }) def get_act_fn(act_fn_name: str) -> nn.Module: """Get an activation function by name.""" act_fn_name = act_fn_name.lower() + + if act_fn_name.startswith("torch.nn.modules."): + activation_name = act_fn_name.split(".")[-1] + if activation_name == "identity": + return nn.Identity() + act_fn_name = activation_name + if act_fn_name not in _ACTIVATION_REGISTRY: raise ValueError( f"Activation function {act_fn_name!r} is not supported.") @@ -406,9 +447,14 @@ def get_act_fn(act_fn_name: str) -> nn.Module: _ACTIVATION_AND_MUL_REGISTRY = LazyDict({ - "gelu": lambda: GeluAndMul(), - "silu": lambda: SiluAndMul(), - "geglu": lambda: GeluAndMul(), + "gelu": + lambda: GeluAndMul(), + "silu": + lambda: SiluAndMul(), + "geglu": + lambda: GeluAndMul(), + "swigluoai": + lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs), }) diff --git a/vllm/model_executor/layers/attention_layer_base.py b/vllm/model_executor/layers/attention_layer_base.py new file mode 100644 index 0000000000000..782818f55fbc2 --- /dev/null +++ b/vllm/model_executor/layers/attention_layer_base.py @@ -0,0 +1,23 @@ +# 
SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Base class for attention-like layers.""" +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + + +class AttentionLayerBase(ABC): + """ + Base class for attention-like layers (Attention, Mamba, etc.) + that support the v1 engine. + + This provides a common interface for getting attention backends + from different layer types. + """ + + @abstractmethod + def get_attn_backend(self) -> type["AttentionBackend"]: + """Get the attention backend class for this layer.""" + pass diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 3d40879b4ccbf..3007643d7a288 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -49,7 +49,8 @@ if HAS_TRITON: from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 BatchedTritonOrDeepGemmExperts) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8) + CutlassBatchedExpertsFp8, CutlassExpertsFp8, cutlass_moe_fp4, + cutlass_moe_fp8) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts) from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( @@ -69,6 +70,7 @@ if HAS_TRITON: "cutlass_moe_fp8", "cutlass_moe_fp4", "CutlassExpertsFp8", + "CutlassBatchedExpertsFp8", "TritonExperts", "BatchedTritonExperts", "DeepGemmExperts", diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index c48a0137c3060..a5326dfe84f6d 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -1,6 +1,6 @@ # 
SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import torch @@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked, - is_blackwell_deep_gemm_e8m0_used) + is_deep_gemm_e8m0_used) logger = init_logger(__name__) @@ -70,53 +70,51 @@ def _silu_mul_fp8_quant_deep_gemm( # number of valid tokens for this expert n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64) - cols = tl.arange(0, BLOCK) - cols = cols.to(tl.int64) - mask_h = cols < BLOCK + cols = tl.arange(0, BLOCK).to(tl.int64) + mask = cols < BLOCK + + base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h + base_gate_offset = base_input_offset + cols * stride_i_h + base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h + base_yq_offset = (e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + + cols * stride_yq_h) + base_ys_offset = e * stride_ys_e + g * stride_ys_g for t in tl.range(0, n_tokens, num_stages=NUM_STAGES): - base_i_offset = (e * stride_i_e + t * stride_i_t + - g * GROUP_SIZE * stride_i_h) - base_yq_offset = (e * stride_yq_e + t * stride_yq_t + - g * GROUP_SIZE * stride_yq_h) - base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g - - mask = mask_h - x = tl.load(input_ptr + base_i_offset + cols * stride_i_h, - mask=mask, - other=0.0).to(tl.float32) - y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h + - cols * stride_i_h, + gate = tl.load(input_ptr + base_gate_offset + t * stride_i_t, + mask=mask, + other=0.0).to(tl.float32) + up = tl.load(input_ptr + base_up_offset + t * stride_i_t, mask=mask, - other=0.0).to(tl.float32) + other=0.0) - x = x * (1.0 / (1.0 + tl.exp(-x))) - y = x * y2 + gate = gate * (1.0 / (1.0 + 
tl.exp(-gate))) + y = gate * up + + y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max + if use_ue8m0: + y_s = tl.exp2(tl.ceil(tl.log2(y_s))) - _absmax = tl.maximum(tl.max(tl.abs(y)), eps) - scale_raw = _absmax / fp8_max - y_s = tl.math.exp2(tl.ceil( - tl.log2(scale_raw))) if use_ue8m0 else scale_raw y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) - tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask) - tl.store(y_s_ptr + base_ys_offset, y_s) + tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask) + tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s) def silu_mul_fp8_quant_deep_gemm( - y: torch.Tensor, # (E, T, 2*H) float32 + y: torch.Tensor, # (E, T, 2*H) tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert group_size: int = 128, eps: float = 1e-10, -): +) -> tuple[torch.Tensor, torch.Tensor]: """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales y has shape (E, T, 2*H). The first half of the last dimension is silu-activated, multiplied by the second half, then quantized into FP8. Returns `(y_q, y_s)` where - * `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`. - * `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)` + * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H] + * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) """ assert y.ndim == 3, "y must be (E, T, 2*H)" E, T, H2 = y.shape @@ -148,7 +146,7 @@ def silu_mul_fp8_quant_deep_gemm( stride_cnt_e = tokens_per_expert.stride()[0] - # static grid over experts and H-groups. + # Static grid over experts and H-groups. 
# A loop inside the kernel handles the token dim grid = (E * G, ) @@ -176,9 +174,9 @@ def silu_mul_fp8_quant_deep_gemm( eps, fp8_min, fp8_max, - is_blackwell_deep_gemm_e8m0_used(), + is_deep_gemm_e8m0_used(), BLOCK=group_size, - NUM_STAGES=8, + NUM_STAGES=4, num_warps=1, ) @@ -254,18 +252,28 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): output = (num_experts, max_num_tokens * num_dispatchers, K) return (workspace13, workspace2, output, a.dtype) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): assert expert_tokens_meta is not None expert_num_tokens = expert_tokens_meta.expert_num_tokens diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py 
index fc30e84e6656d..89d7412ee2236 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import torch @@ -132,18 +132,28 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): a, aq, M, N, K, topk, global_num_experts, local_num_experts, expert_tokens_metadata) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): experts = (self.batched_deep_gemm_experts if self.allow_deep_gemm else self.batched_triton_experts) assert experts is not None @@ -151,4 +161,4 @@ 
class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, workspace2, expert_tokens_meta, - apply_router_weight_on_input, extra_expert_args) + apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 31ea826f1f97a..7c1a7b636a9c2 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -45,7 +45,6 @@ def get_quant_config_weight_quant( return _get_quant_config_quantization_args(quant_config, "weights") -# TODO (bnell): use scalar_type instead of bools? def get_config_quant_dtype( use_fp8_w8a8: bool, use_int8_w8a8: bool, @@ -65,7 +64,8 @@ def get_config_quant_dtype( @dataclass class FusedMoEQuantConfig: # The post quantization activation type. - quant_dtype: Optional[torch.dtype] = None + # TODO (bnell): use scalar_type instead of Union. + quant_dtype: Union[torch.dtype, str, None] = None per_act_token_quant: bool = False per_out_ch_quant: bool = False block_shape: Optional[list[int]] = None @@ -141,6 +141,7 @@ class FusedMoEQuantConfig: use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, + use_mxfp4_w4a4, ] ]) <= 1, "Quantization flags are mutually exclusive." 
@@ -334,7 +335,7 @@ class FusedMoEConfig: assert self.max_num_tokens > 0 @property - def quant_dtype(self) -> Optional[torch.dtype]: + def quant_dtype(self) -> Union[torch.dtype, str, None]: if self.quant_config is not None: return self.quant_config.quant_dtype else: @@ -429,7 +430,7 @@ class FusedMoEConfig: block_shape = None per_act_token_quant = False per_out_ch_quant = False - quant_dtype: Optional[torch.dtype] = None + quant_dtype: Union[torch.dtype, str, None] = None input_quant = get_quant_config_input_quant(quant_config) weight_quant = get_quant_config_weight_quant(quant_config) @@ -453,7 +454,7 @@ class FusedMoEConfig: ModelOptNvFp4Config) if quant_dtype is None and isinstance(quant_config, ModelOptNvFp4Config): - quant_dtype = torch.uint8 + quant_dtype = "nvfp4" if weight_quant is not None: per_out_ch_quant = ( diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..63de4bfa4cb52 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,122 @@ +{ + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 
16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16384": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..b962d19506ce5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..6efcc02b4d9a2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,114 @@ +{ + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 
128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16384": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..d677d69c57a25 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,154 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git 
a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py index e67ff66882102..769a04b7de89d 100644 --- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py @@ -3,10 +3,110 @@ from typing import Callable, Optional import torch +from torch.nn import functional as F from vllm import envs +def silu_and_mul(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] + + +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + + gating_output = gating_output.float() + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + num_token = scores.shape[0] + if e_score_correction_bias is not None: + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) + else: + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] + + if 
e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids.to(torch.int32) + + +def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + return grouped_topk(hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + elif custom_routing_function is None: + assert scoring_func == "softmax" + topk_weights = torch.nn.functional.softmax(router_logits, + dim=1, + dtype=torch.float32) + topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1) + if renormalize: + topk_weights /= topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids.to(torch.int32) + else: + return custom_routing_function(hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + + class IPEXFusedMOE: def __init__(self, layer: torch.nn.Module) -> None: @@ -56,113 +156,6 @@ class SGLFusedMOE: def __init__(self, layer: torch.nn.Module) -> None: pass - @staticmethod - def _grouped_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - num_expert_group: int = 0, - 
topk_group: int = 0, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None - ) -> tuple[torch.Tensor, torch.Tensor]: - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - - gating_output = gating_output.float() - if scoring_func == "softmax": - scores = torch.softmax(gating_output, dim=-1) - elif scoring_func == "sigmoid": - scores = gating_output.sigmoid() - else: - raise ValueError(f"Unsupported scoring function: {scoring_func}") - - num_token = scores.shape[0] - if e_score_correction_bias is not None: - # Store original scores before applying correction bias. We use - # biased scores for expert selection but original scores for - # routing weights - original_scores = scores - scores = scores + e_score_correction_bias.unsqueeze(0) - group_scores = (scores.view(num_token, num_expert_group, - -1).topk(2, dim=-1)[0].sum(dim=-1)) - else: - group_scores = scores.view(num_token, num_expert_group, - -1).max(dim=-1).values # [n, n_group] - group_idx = torch.topk(group_scores, - k=topk_group, - dim=-1, - sorted=False)[1] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = group_mask.unsqueeze(-1).expand( - num_token, num_expert_group, - scores.shape[-1] // num_expert_group).reshape(num_token, - -1) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), - float("-inf")) # [n, e] - - if e_score_correction_bias is not None: - topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] - # Use original unbiased scores for the routing weights - topk_weights = original_scores.gather(1, topk_ids) - else: - topk_weights, topk_ids = torch.topk(tmp_scores, - k=topk, - dim=-1, - sorted=False) - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, - keepdim=True) - - return topk_weights, topk_ids.to(torch.int32) - - @staticmethod - def _select_experts( - hidden_states: 
torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # DeekSeekv2 uses grouped_top_k - if use_grouped_topk: - assert topk_group is not None - assert num_expert_group is not None - topk_weights, topk_ids = SGLFusedMOE._grouped_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - num_expert_group=num_expert_group, - topk_group=topk_group, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) - elif custom_routing_function is None: - assert scoring_func == "softmax" - topk_weights = torch.nn.functional.softmax(router_logits, - dim=1, - dtype=torch.float32) - topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1) - if renormalize: - topk_weights /= topk_weights.sum(dim=-1, keepdim=True) - topk_ids = topk_ids.to(torch.int32) - else: - topk_weights, topk_ids = custom_routing_function( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize) - - return topk_weights, topk_ids - def __call__( self, layer: torch.nn.Module, @@ -183,7 +176,7 @@ class SGLFusedMOE: ) -> torch.Tensor: assert activation == "silu", f"{activation} is not supported." 
assert not apply_router_weight_on_input - topk_weights, topk_ids = SGLFusedMOE._select_experts( + topk_weights, topk_ids = select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -213,3 +206,80 @@ class SGLFusedMOE: True, ) return x + + +class CPUFusedMOE: + + def __init__(self, layer: torch.nn.Module) -> None: + pass + + def __call__( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert activation == "silu", f"{activation} is not supported." + assert not apply_router_weight_on_input + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + # Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53 + len_experts = global_num_experts + + cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts)) + cnts.scatter_(1, topk_ids.to(torch.int64), 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + + sorted_tokens = x[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + 
num_tokens + if num_tokens == 0: + continue + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + + layer_w13_weight = layer.w13_weight[i] + layer_w2_weight = layer.w2_weight[i] + + gate_up = F.linear(tokens_for_this_expert, layer_w13_weight) + gate_up = silu_and_mul(gate_up) + expert_out = F.linear(gate_up, layer_w2_weight) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, + dim=0) if len(outputs) else sorted_tokens.new_empty(0) + new_x = torch.empty_like(outs) + + new_x[idxs] = outs + final_out = (new_x.view( + *topk_ids.shape, -1).type(topk_weights.dtype).mul_( + topk_weights.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype)) + return final_out diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 2585a2953c9db..95d23ec0346c1 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ CUTLASS based Fused MoE kernels.""" -from typing import Any, Callable, Optional +from typing import Callable, Optional import torch @@ -9,14 +9,14 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + moe_permute, moe_unpermute) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate) -from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, - _fp8_quantize, - _resize_cache, - extract_required_args) + TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP) +from vllm.model_executor.layers.fused_moe.utils 
import (_fp8_quantize, + _resize_cache) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -35,6 +35,10 @@ def run_cutlass_moe_fp8( w2_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], @@ -42,6 +46,7 @@ def run_cutlass_moe_fp8( per_act_token: bool, per_out_ch: bool, use_batched_format: bool, + topk_weights: Optional[torch.Tensor], ): a1q = hidden_states @@ -100,6 +105,22 @@ def run_cutlass_moe_fp8( topk = local_topk_ids.size(1) local_E = w1.size(0) + if use_batched_format: + mm1_out = _resize_cache(workspace13, (local_E * padded_M, N * 2)) + act_out = _resize_cache(workspace2, (local_E * padded_M, N)) + quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), + (local_E * padded_M, N)) + mm2_out = _resize_cache(workspace2, (local_E * padded_M, K)) + else: + a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), + (M * topk, K)) + mm1_out = _resize_cache(workspace13, (M * topk, N * 2)) + act_out = _resize_cache(workspace2, (M * topk, N)) + # original workspace are based on input hidden_states dtype (bf16) + quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), + (M * topk, N)) + mm2_out = _resize_cache(workspace2, (M * topk, K)) + if use_batched_format: assert expert_num_tokens is not None @@ -121,11 +142,10 @@ def run_cutlass_moe_fp8( w2_scale = w2_scale.reshape(w2_scale.size(0), -1) a1q = a1q.reshape(-1, a1q.size(2)) a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous() - + # c3x get_group_gemm_starts expects int64 to avoid overflow + # during offset calculations + expert_offsets = expert_offsets.to(torch.int64) else: - expert_offsets = torch.empty((global_num_experts + 1), - dtype=torch.int32, - device=device) problem_sizes1 = 
torch.empty((global_num_experts, 3), dtype=torch.int32, device=device) @@ -133,99 +153,71 @@ def run_cutlass_moe_fp8( dtype=torch.int32, device=device) - # With expert_map each Rank processes only a subset of experts. As - # a result not all of a_map and c2 tensors are filled. We fill it - # zeros for correctness. - if expert_map is not None: - a_map = torch.zeros((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - else: - a_map = torch.empty((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - - c_map = torch.empty((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - - ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, - problem_sizes1, problem_sizes2, a_map, - c_map, global_num_experts, N, K) - - a1q = _fp8_perm(a1q, a_map) - a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale + num_expert = global_num_experts if expert_map is None \ + else expert_map.size(0) + # permuted a1q reuses workspace2 + a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute( + a1q, + a1q_scale, + topk_ids, + num_expert, + local_E, + expert_map, + permuted_hidden_states=a1q_perm) expert_offsets = expert_offsets[:-1] - ab_strides1 = torch.full((w1.size(0), ), - K, - device=device, - dtype=torch.int64) - c_strides1 = torch.full((w1.size(0), ), - 2 * N, - device=device, - dtype=torch.int64) - ab_strides2 = torch.full((w1.size(0), ), - N, - device=device, - dtype=torch.int64) - c_strides2 = torch.full((w1.size(0), ), - K, - device=device, - dtype=torch.int64) - - if use_batched_format: - c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2)) - c2 = _resize_cache(workspace2, (local_E * padded_M, N)) - c3 = _resize_cache(workspace13, (local_E * padded_M, K)) - else: - c1 = _resize_cache(workspace13, (M * topk, N * 2)) - c2 = _resize_cache(workspace2, (M * topk, N)) - c3 = _resize_cache(workspace13, (M * topk, K)) + ops.get_cutlass_moe_mm_problem_sizes(local_topk_ids, problem_sizes1, + problem_sizes2, + global_num_experts, N, K) 
if not per_act_token and (expert_map is not None or use_batched_format): # this is necessary to avoid imprecise scale calculation caused by # random data in the unused workspace. The workspace is unused when # this rank handles only partial tokens, or when it is batched . - c1.fill_(0) + mm1_out.fill_(0) - ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets, + ops.cutlass_moe_mm(mm1_out, a1q, w1, a1q_scale, w1_scale, expert_offsets, problem_sizes1, ab_strides1, ab_strides1, c_strides1, per_act_token, per_out_ch) - activation_callable(c2, c1) + activation_callable(act_out, mm1_out) a2q, a2q_scale = ops.scaled_fp8_quant( - c2, a2_scale, use_per_token_if_dynamic=per_act_token) + act_out, + a2_scale, + use_per_token_if_dynamic=per_act_token, + output=quant_out) if expert_map is not None: - c3.fill_(0) + mm2_out.fill_(0) - ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets, + ops.cutlass_moe_mm(mm2_out, a2q, w2, a2q_scale, w2_scale, expert_offsets, problem_sizes2, ab_strides2, ab_strides2, c_strides2, per_act_token, per_out_ch) if use_batched_format: - output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True) + output.copy_(mm2_out.reshape(local_E, padded_M, K), non_blocking=True) else: - # We can't do this inplace because output may point to the same tensor - # as c3. - output.copy_(c3[c_map].view(M * topk, K), non_blocking=True) + # for non-chunking mode the output is resized from workspace13 + # so we need to make sure mm2_out uses workspace2. + moe_unpermute(out=output, + permuted_hidden_states=mm2_out, + topk_weights=topk_weights, + inv_permuted_idx=inv_perm) -# TODO (bnell): split class batched vs. non-batched? 
-# maybe remove need for passing aq to workspace_shapes -class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): +class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - max_experts_per_worker: int, out_dtype: Optional[torch.dtype], per_act_token_quant: bool, per_out_ch_quant: bool, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, block_shape: Optional[list[int]] = None, - num_dispatchers: Optional[int] = None, - use_batched_format: bool = False, ): super().__init__( FusedMoEQuantConfig( @@ -234,33 +226,101 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): per_out_ch_quant=per_out_ch_quant, block_shape=block_shape, )) - assert max_experts_per_worker > 0 - assert not use_batched_format or num_dispatchers is not None - self.max_experts_per_worker = max_experts_per_worker - self.num_dispatchers = num_dispatchers self.out_dtype = out_dtype - self.use_batched_format = use_batched_format + self.ab_strides1 = ab_strides1 + self.ab_strides2 = ab_strides2 + self.c_strides1 = c_strides1 + self.c_strides2 = c_strides2 + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. 
+ return TopKWeightAndReduceDelegate() + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): + assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" + assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" + + expert_num_tokens = None + if expert_tokens_meta is not None: + expert_num_tokens = expert_tokens_meta.expert_num_tokens + + activation_callable = lambda o, i: self.activation(activation, o, i) + + use_batched_format = self.activation_formats[ + 0] == mk.FusedMoEActivationFormat.BatchedExperts + + in_dtype = hidden_states.dtype + run_cutlass_moe_fp8( + output, hidden_states, w1, w2, topk_ids, activation_callable, + global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, + a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1, + self.c_strides2, workspace13, workspace2, expert_num_tokens, + self.out_dtype if self.out_dtype is not None else in_dtype, + self.per_act_token_quant, self.per_out_ch_quant, + use_batched_format, topk_weights) + + +class CutlassExpertsFp8(CutlassExpertsFp8Base): + + def __init__( + self, + out_dtype: Optional[torch.dtype], + per_act_token_quant: bool, + per_out_ch_quant: bool, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, + block_shape: Optional[list[int]] = None, + ): + super().__init__( + out_dtype, + per_act_token_quant, + per_out_ch_quant, + ab_strides1, + ab_strides2, + 
c_strides1, + c_strides2, + block_shape, + ) @property def activation_formats( self ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - if self.use_batched_format: - return (mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts) - else: - return (mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard) + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) def supports_chunking(self) -> bool: - return not self.use_batched_format + return True def supports_expert_map(self) -> bool: - return not self.use_batched_format + return True def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - # Let PrepareAndFinalize::finalize() decide the impl. - return TopKWeightAndReduceDelegate() + # topk weights and reduction are fused in moe_unpermute cuda kernel + return TopKWeightAndReduceNoOP() def workspace_shapes( self, @@ -274,54 +334,78 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - workspace1: tuple[int, ...] = () - workspace2: tuple[int, ...] = () - output: tuple[int, ...] 
= () - if self.use_batched_format: - padded_M = aq.size(1) - num_dp = self.num_dispatchers - assert num_dp is not None - workspace1 = (self.max_experts_per_worker, padded_M * num_dp, - max(N, K)) - workspace2 = (self.max_experts_per_worker, padded_M * num_dp, - (N // 2)) - output = (self.max_experts_per_worker, padded_M, K) - else: - workspace1 = (M * topk, max(N, K)) - workspace2 = (M * topk, N // 2) - output = (M * topk, K) + workspace1 = (M * topk, max(N, K)) + workspace2 = (M * topk, max(N // 2, K)) + output = (M, K) return (workspace1, workspace2, output, self.out_dtype if self.out_dtype is not None else a.dtype) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): - assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" - assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" - expert_num_tokens = None - if expert_tokens_meta is not None: - expert_num_tokens = expert_tokens_meta.expert_num_tokens +class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): - activation_callable = lambda o, i: self.activation(activation, o, i) + def __init__( + self, + max_experts_per_worker: int, + num_dispatchers: int, + out_dtype: Optional[torch.dtype], + per_act_token_quant: bool, + per_out_ch_quant: bool, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, + block_shape: Optional[list[int]] = None, + ): + 
super().__init__( + out_dtype, + per_act_token_quant, + per_out_ch_quant, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, + block_shape, + ) + assert max_experts_per_worker > 0 + self.max_experts_per_worker = max_experts_per_worker + self.num_dispatchers = num_dispatchers - in_dtype = hidden_states.dtype - run_cutlass_moe_fp8( - output, hidden_states, w1, w2, topk_ids, activation_callable, - global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, - a2_scale, workspace13, workspace2, expert_num_tokens, - self.out_dtype if self.out_dtype is not None else in_dtype, - self.per_act_token_quant, self.per_out_ch_quant, - self.use_batched_format) + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) + + def supports_chunking(self) -> bool: + return False + + def supports_expert_map(self) -> bool: + return False + + # TODO(bnell): maybe remove need for passing aq to workspace_shapes + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + padded_M = aq.size(1) + num_dp = self.num_dispatchers + assert num_dp is not None + workspace1 = (self.max_experts_per_worker, padded_M * num_dp, + max(N, K)) + workspace2 = (self.max_experts_per_worker, padded_M * num_dp, + max(N // 2, K)) + output = (self.max_experts_per_worker, padded_M, K) + return (workspace1, workspace2, output, + self.out_dtype if self.out_dtype is not None else a.dtype) def cutlass_moe_fp8( @@ -332,6 +416,10 @@ def cutlass_moe_fp8( topk_ids: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + 
c_strides2: torch.Tensor, per_act_token: Optional[bool] = None, activation: str = "silu", a1_scale: Optional[torch.Tensor] = None, @@ -359,6 +447,17 @@ def cutlass_moe_fp8( Shape: [num_experts] or [num_experts, 2N] - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. Shape: [num_experts] or [num_experts, K] + - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm. + Shape: [num_experts] + - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm. + Shape: [num_experts] + - c_strides1 (torch.Tensor): The output strides for the first gemm. + Shape: [num_experts] + - c_strides2 (torch.Tensor): The output strides for the second gemm. + Shape: [num_experts] + - per_act_token (Optional[bool]): Whether the scale is per-token or + per-tensor. + - activation (str): The activation function to use. - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. Shape: scalar or [M] - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to @@ -387,11 +486,13 @@ def cutlass_moe_fp8( fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp8( - max_experts_per_worker=num_experts, out_dtype=a.dtype, per_act_token_quant=per_act_token, per_out_ch_quant=per_out_ch, - use_batched_format=False, + ab_strides1=ab_strides1, + ab_strides2=ab_strides2, + c_strides1=c_strides1, + c_strides2=c_strides2, ), ) @@ -476,8 +577,9 @@ def run_cutlass_moe_fp4( e_w1, nx2_w1, half_k_w1 = w1_fp4.shape e_w2, k_w2, half_n_w2 = w2_fp4.shape - assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match", - " between weights.") + assert (e_w1 == e_w2 + and e_w1 == e), ("Number of experts must match", + f" between weights. 
{e_w1}, {e_w2}, {e}") assert (k_a == half_k_w1 * 2 and k == k_w2), ("Hidden size mismatch between a, w1 and w2") assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in " @@ -554,6 +656,10 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, max_experts_per_worker: int, out_dtype: torch.dtype, per_act_token_quant: bool, @@ -562,8 +668,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): use_batched_format: bool = False, ): super().__init__( + # NVFP4 requires two levels of quantization, which involves + # computing some scaling factors dynamically. This makes it + # incompatible with the typical prepare -> MoE -> finalize + # pipeline. Move the quantization logic into the MoE body. FusedMoEQuantConfig( - quant_dtype=torch.uint8, + quant_dtype=None, # skip quantization in prepare/finalize per_act_token_quant=per_act_token_quant, per_out_ch_quant=per_out_ch_quant, block_shape=block_shape, @@ -572,6 +682,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): self.out_dtype = out_dtype self.use_batched_format = use_batched_format + # TODO(bnell): put this stuff into quant config? + self.g1_alphas = g1_alphas + self.g2_alphas = g2_alphas + self.a1_gscale = a1_gscale + self.a2_gscale = a2_gscale + @property def activation_formats( self @@ -590,8 +706,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): return True def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - # Let PrepareAndFinalize::finalize() decide the impl. 
- return TopKWeightAndReduceDelegate() + return TopKWeightAndReduceNoOP() def workspace_shapes( self, @@ -620,34 +735,42 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): return (workspace1, workspace2, output, self.out_dtype if self.out_dtype is not None else a.dtype) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor, - w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor], - workspace2: Optional[torch.Tensor], - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): - required_keys = [ - "g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "m", "n", "k", - "e", "device" - ] - (g1_alphas, g2_alphas, a1_gscale, a2_gscale, m, n, k, e, - device) = extract_required_args(extra_expert_args, required_keys) + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: torch.Tensor, + workspace13: Optional[torch.Tensor], + workspace2: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): + e, m, n, k, _ = mk._moe_problem_size(hidden_states, w1, w2, topk_ids) + n = w2.shape[2] * 2 + run_cutlass_moe_fp4( output=output, a=hidden_states, - a1_gscale=a1_gscale, + a1_gscale=self.a1_gscale, w1_fp4=w1, 
w1_blockscale=w1_scale, - w1_alphas=g1_alphas, - a2_gscale=a2_gscale, + w1_alphas=self.g1_alphas, + a2_gscale=self.a2_gscale, w2_fp4=w2, w2_blockscale=w2_scale, - w2_alphas=g2_alphas, + w2_alphas=self.g2_alphas, topk_weights=topk_weights, topk_ids=topk_ids, workspace13=workspace13, @@ -656,7 +779,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): n=n, k=k, e=e, - device=device, + device=hidden_states.device, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -677,7 +800,6 @@ def cutlass_moe_fp4( n: int, k: int, e: int, - device: torch.device, expert_map: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False) -> torch.Tensor: assert expert_map is None, ("Expert Parallelism / expert_map " @@ -686,6 +808,10 @@ def cutlass_moe_fp4( fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp4( + g1_alphas, + g2_alphas, + a1_gscale, + a2_gscale, max_experts_per_worker=e, out_dtype=a.dtype, per_act_token_quant=False, @@ -693,29 +819,7 @@ def cutlass_moe_fp4( use_batched_format=False, ), ) - extra_expert_args = { - 'g1_alphas': g1_alphas, - 'g2_alphas': g2_alphas, - 'a1_gscale': a1_gscale, - 'a2_gscale': a2_gscale, - 'm': m, - 'n': n, - 'k': k, - 'e': e, - 'device': device, - } - # NVFP4 requires two levels of quantization, which involves computing some - # scaling factors dynamically. This makes it incompatible with the typical - # prepare -> MoE -> finalize pipeline. Move the quantization logic into the - # MoE body. - extra_prepare_args = { - 'skip_quant': True, - } - # Similar reason as above. 
- extra_finalize_args = { - 'skip_weight_reduce': True, - } return fn( hidden_states=a, w1=w1_fp4, @@ -731,9 +835,6 @@ def cutlass_moe_fp4( a1_scale=None, a2_scale=None, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=extra_expert_args, - extra_prepare_args=extra_prepare_args, - extra_finalize_args=extra_finalize_args, ) @@ -824,16 +925,6 @@ def run_cutlass_block_scaled_fused_experts( k = w1_q.size(1) n = w2_q.size(1) - expert_offsets = torch.empty((num_experts + 1, ), - dtype=torch.int32, - device="cuda") - problem_sizes1 = torch.empty((num_experts, 3), - dtype=torch.int32, - device="cuda") - problem_sizes2 = torch.empty((num_experts, 3), - dtype=torch.int32, - device="cuda") - topk = topk_ids.size(1) a_q, a1_scale = _fp8_quantize(a, @@ -842,6 +933,16 @@ def run_cutlass_block_scaled_fused_experts( block_shape=[128, 128]) device = a_q.device + expert_offsets = torch.empty((num_experts + 1, ), + dtype=torch.int32, + device=device) + problem_sizes1 = torch.empty((num_experts, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((num_experts, 3), + dtype=torch.int32, + device=device) + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 9b8175f42a9d2..7b8467a5a0cf0 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools -from typing import Any, Optional +from typing import Optional import torch from tqdm import tqdm @@ -230,7 +230,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], 
apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]], ): assert self.block_shape is not None assert a1q_scale is not None diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index f6b62254e7b4c..437e569d3130d 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import deep_ep import torch @@ -127,12 +127,16 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_topk_weights) def prepare( - self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, - topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -187,11 +191,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, expert_topk_weights) - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + def 
finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: assert self.handle is not None diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index cfc2bdcf02408..93ac11fb4bfbf 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional, Union +from typing import Optional, Union import deep_ep import torch @@ -77,7 +77,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], a1_dtype: torch.dtype, - quant_dtype: Optional[torch.dtype], + quant_dtype: Union[torch.dtype, str, None], per_act_token_quant: bool, block_shape: Optional[list[int]], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -111,12 +111,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return x, x_scales def prepare( - self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, - topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], 
Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -162,11 +166,15 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return (expert_x, expert_x_scale, expert_tokens_meta, None, None) - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: assert isinstance( weight_and_reduce_impl, TopKWeightAndReduceDelegate ), ("Weight application and reduction happens in the combine kernel.") diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 4e3e15a35ada2..feab3f74cac53 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -1,15 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional, Union import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 + FlashInferCutlassMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate) -from vllm.model_executor.layers.fused_moe.utils import extract_required_args + TopKWeightAndReduceNoOP) from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe, 
has_flashinfer_cutlass_fused_moe) @@ -20,7 +21,7 @@ def is_valid_flashinfer_cutlass_fused_moe(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> bool: """ - Check if the given problem size is supported by the FlashInfer CUTLASS MoE + Check if the given problem size is supported by the FlashInfer CUTLASS MoE kernel. """ if not has_flashinfer_cutlass_fused_moe(): @@ -43,31 +44,34 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - use_nvfp4_w4a4: bool = False, - use_fp8_w8a8: bool = False, - use_dp: bool = False, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, + out_dtype: torch.dtype, + quant_dtype: Union[torch.dtype, str, None], ep_rank: int = 0, ep_size: int = 1, tp_rank: int = 0, tp_size: int = 1, - num_dispatchers: Optional[int] = None, - use_batched_format: bool = False, ): super().__init__( FusedMoEQuantConfig( - quant_dtype=torch.uint8, + quant_dtype=quant_dtype, per_act_token_quant=False, block_shape=None, )) - self.use_nvfp4_w4a4 = use_nvfp4_w4a4 - self.use_fp8_w8a8 = use_fp8_w8a8 + assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), ( + "Only nvfp4,fp8 quantization are currently supported.") self.ep_rank = ep_rank self.ep_size = ep_size self.tp_rank = tp_rank self.tp_size = tp_size - self.use_dp = use_dp - assert not use_batched_format or num_dispatchers is not None - self.num_dispatchers = num_dispatchers + self.g1_alphas = g1_alphas + self.g2_alphas = g2_alphas + self.a1_gscale = a1_gscale + self.a2_gscale = a2_gscale + self.out_dtype = out_dtype @property def activation_formats( @@ -84,8 +88,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): return True def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - # Let PrepareAndFinalize::finalize() decide the impl. 
- return TopKWeightAndReduceDelegate() + return TopKWeightAndReduceNoOP() def workspace_shapes( self, @@ -117,11 +120,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): - Note: in order for activation chunking to work, the first dimension of each tuple must be the number of tokens. """ - assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is " - "currently supported.") aq_m, aq_n = aq.shape workspace2 = () - output_shape = (aq_m, aq_n * 2) + output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \ + torch.float8_e4m3fn else (aq_m, aq_n) workspace_dtype = a.dtype workspace1 = output_shape # The workspace is determined by `aq`, since it comes after any @@ -149,43 +151,41 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: Optional[bool], - extra_expert_args: Optional[dict[str, Any]], ): - assert extra_expert_args is not None, \ - "extra_expert_args must be provided" - required_keys = [ - 'g1_alphas', 'g2_alphas', 'a1_gscale', 'a2_gscale', 'out_dtype' - ] + if self.quant_dtype == torch.float8_e4m3fn: + quant_scales = [ + self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale + ] - g1_alphas, g2_alphas, a1_gscale, a2_gscale, out_dtype = ( - extract_required_args(extra_expert_args, required_keys)) + a1q_scale = None # not passing input_sf in fp8 + fc1_expert_weights = w1 + fc2_expert_weights = w2 + else: + # Ensure w1_scale and w2_scale are not None before calling view + assert w1_scale is not None and w2_scale is not None, ( + "w1_scale and w2_scale must not " + "be None for FlashInferExperts") + # Flashinfer CUTLASS kernel takes scalar global scales, + # min because inv_scale. 
+ quant_scales = [ + self.a1_gscale, + w1_scale.view(torch.int32), + self.g1_alphas, + self.a2_gscale, + w2_scale.view(torch.int32), + self.g2_alphas, + ] + # FlashInfer API requires weight to be long for nvfp4 + fc1_expert_weights = w1.view(torch.long) + fc2_expert_weights = w2.view(torch.long) - # Flashinfer CUTLASS kernel takes scalar global scales, - # min because inv_scale. - assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is " - "currently supported.") - - # Ensure w1_scale and w2_scale are not None before calling view - assert w1_scale is not None and w2_scale is not None, ( - "w1_scale and w2_scale must not " - "be None for FlashInferExperts") - - quant_scales = [ - a1_gscale, - w1_scale.view(torch.int32), - g1_alphas, - a2_gscale, - w2_scale.view(torch.int32), - g2_alphas, - ] _ = flashinfer_cutlass_fused_moe( input=hidden_states, token_selected_experts=topk_ids.to(torch.int), token_final_scales=topk_weights, - # FlashInfer API requires weight to be long for nvfp4 - fc1_expert_weights=w1.view(torch.long), - fc2_expert_weights=w2.view(torch.long), - output_dtype=out_dtype, + fc1_expert_weights=fc1_expert_weights, + fc2_expert_weights=fc2_expert_weights, + output_dtype=self.out_dtype, quant_scales=quant_scales, input_sf=a1q_scale, tp_size=self.tp_size, @@ -194,3 +194,50 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ep_rank=self.ep_rank, output=output, ) + + +def flashinfer_cutlass_moe_fp4( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, +) -> torch.Tensor: + + fused_experts = mk.FusedMoEModularKernel( + 
FlashInferCutlassMoEPrepareAndFinalize(use_dp=False, + a1_gscale=a1_gscale), + FlashInferExperts( + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, + a1_gscale=a1_gscale, + a2_gscale=a2_gscale, + out_dtype=hidden_states.dtype, + quant_dtype="nvfp4", + )) + + return fused_experts( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=inplace, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 36aca8cf74b6d..061b02172c446 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import torch @@ -9,7 +9,7 @@ from vllm.distributed import get_dp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( - extract_required_args, moe_kernel_quantize_input) + moe_kernel_quantize_input) from vllm.utils.flashinfer import nvfp4_block_scale_interleave @@ -21,16 +21,15 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def __init__( self, - quant_dtype: Optional[torch.dtype] = None, - per_channel_quant: bool = False, - block_shape: Optional[list[int]] = None, + use_dp: bool, + a1_gscale: Optional[torch.Tensor], num_dispatchers: int = 1, ): super().__init__() - self.per_channel_quant = per_channel_quant - self.block_shape = block_shape - self.quant_dtype = 
quant_dtype self.num_dispatchers_ = num_dispatchers + self.use_dp = use_dp + self.a1_gscale = a1_gscale + self.local_tokens = None @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -55,10 +54,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + # TODO(bnell): use quant_config + scales instead of ctor args quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], + Optional[torch.Tensor]]: if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -67,22 +67,22 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): "apply_router_weight_on_input is only implemented for topk=1" a1.mul_(topk_weights.to(a1.dtype)) - (a1_gscale, use_dp, local_tokens) = extract_required_args( - extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens']) - a1q, a1q_scale = moe_kernel_quantize_input( a1, - a1_gscale, + self.a1_gscale, quant_config.quant_dtype, - self.per_channel_quant, - self.block_shape, - is_fp4_scale_swizzled=not use_dp, # Swizzling after communication + quant_config.per_act_token_quant, + quant_config.block_shape, + # Swizzling after communication + is_fp4_scale_swizzled=not self.use_dp, ) - if use_dp: + if self.use_dp: topk_weights, topk_ids, a1q, a1q_scale = \ - get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], # noqa: E501 - dim=0, - sizes=get_local_sizes()) + get_dp_group().all_gatherv( + [topk_weights, topk_ids, a1q, a1q_scale], + dim=0, + sizes=get_local_sizes(), + ) a1_m, a1_n = a1q.shape a1q_scale = nvfp4_block_scale_interleave(a1q_scale) @@ -91,13 +91,9 @@ class 
FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None: - (use_dp, - local_tokens) = extract_required_args(extra_finalize_args, - ['use_dp', 'local_tokens']) - if use_dp: + if self.use_dp: fused_expert_output = get_dp_group().reduce_scatterv( fused_expert_output, dim=0, sizes=get_local_sizes()) output.copy_(fused_expert_output) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 9a5c85e120cc1..b46f4be4b912e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused batched MoE kernel.""" -from typing import Any, Optional +from typing import Optional import torch @@ -496,12 +496,16 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return self.num_dispatchers_ def prepare( - self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, - topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], 
Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -590,11 +594,15 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return b_a1, b_a1_scale, expert_tokens_meta, None, None - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank) weight_and_reduce_impl.apply( @@ -688,18 +696,28 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): else: return t.to(f32) * group_broadcast(scale, t.shape) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + 
w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): assert hidden_states.dim() == 3 assert expert_tokens_meta is not None expert_num_tokens = expert_tokens_meta.expert_num_tokens @@ -894,18 +912,28 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): output = (num_experts, max_num_tokens * num_dp, K) return (workspace13, workspace2, output, a.dtype) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): # Check constraints. 
if self.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index a49d41c18438e..1e3ac6cd79f68 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -1,14 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused MoE utilities for GPTQ.""" -import functools from typing import Optional import torch import vllm._custom_ops as ops -from vllm.model_executor.layers.fused_moe.fused_moe import ( - moe_align_block_size, try_get_optimal_moe_config) +from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_make_workspace_new, maybe_warn_marlin_atomic_add) from vllm.scalar_type import ScalarType, scalar_types @@ -98,17 +96,11 @@ def fused_marlin_moe(hidden_states: torch.Tensor, N = w2.shape[1] * 16 topk = topk_ids.shape[1] - get_config_func = functools.partial( - try_get_optimal_moe_config, - w1.shape, - w2.shape, - topk_ids.shape[1], - None, - is_marlin=True, - ) - config = get_config_func(M) - - block_size_m = config["BLOCK_SIZE_M"] + # M block size selection logic + # TODO: tune this further for specific models + for block_size_m in [8, 16, 32, 48, 64]: + if M * topk / E / block_size_m < 0.9: + break if global_num_experts == -1: global_num_experts = E @@ -169,25 +161,13 @@ def fused_marlin_moe(hidden_states: torch.Tensor, if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) - elif activation == "swiglu_oai": - # NOTE: in gpt-oss, the gate_proj and up_proj is interleaved - # - interleaved: gate, up = gate_up[..., ::2], gate_up[..., 1::2] - # - origin: gate, up = gate_up[..., :N], gate_up[..., N:] - - @torch.compile(dynamic=True) 
- def swiglu_oai(gate_up): - alpha = 1.702 - limit = 7.0 - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=limit) - up = up.clamp(min=-limit, max=limit) - glu = gate * torch.sigmoid(gate * alpha) - return (up + 1) * glu - - intermediate_cache2 = swiglu_oai(intermediate_cache1) + elif activation == "swigluoai": + # alpha = 1.702, limit = 7.0 + torch.ops._C.swigluoai_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, 2 * N)) else: raise ValueError(f"Unsupported activation: {activation}. " - "Only silu and swiglu_oai activations are supported.") + "Only silu and swigluoai activations are supported.") if expert_map is not None: intermediate_cache3.zero_() diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1c497fa5521b9..17a5c735a57fe 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled @@ -801,7 +801,6 @@ def get_default_config( K: int, topk: int, dtype: Optional[str], - is_marlin: bool, block_shape: Optional[list[int]] = None, ) -> dict[str, int]: if dtype == "fp8_w8a8" and block_shape is not None: @@ -832,11 +831,6 @@ def get_default_config( config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1} else: config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1} - elif is_marlin: - for block_size_m in [8, 16, 32, 48, 64]: - if M * topk / E / block_size_m < 0.9: - break - return {"BLOCK_SIZE_M": block_size_m} elif M <= E: config = { "BLOCK_SIZE_M": 16, @@ -860,7 +854,6 @@ def 
try_get_optimal_moe_config( top_k: int, dtype: Optional[str], M: int, - is_marlin: bool = False, block_shape: Optional[list[int]] = None, ) -> dict[str, int]: from vllm.model_executor.layers.fused_moe import get_config @@ -883,7 +876,7 @@ def try_get_optimal_moe_config( else: # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, - is_marlin, block_shape) + block_shape) return config @@ -956,8 +949,23 @@ def grouped_topk( num_expert_group: int = 0, topk_group: int = 0, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK and \ + current_platform.is_cuda() and \ + num_expert_group <= 32 and topk <= 32 and \ + e_score_correction_bias is not None: + return fused_grouped_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + e_score_correction_bias=e_score_correction_bias, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor) assert hidden_states.size(0) == gating_output.size(0), ( "Number of tokens mismatch") @@ -1003,9 +1011,38 @@ def grouped_topk( if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights * routed_scaling_factor return topk_weights.to(torch.float32), topk_ids.to(torch.int32) +def fused_grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + e_score_correction_bias: torch.Tensor, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.size(0) == gating_output.size(0), ( + "Number of tokens mismatch") + + if scoring_func == 
"softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + scores_with_bias = scores + e_score_correction_bias.unsqueeze(0) + topk_values, topk_indices = ops.grouped_topk( + scores, scores_with_bias.to(scores.dtype), num_expert_group, + topk_group, topk, renormalize, routed_scaling_factor) + return topk_values.to(torch.float32), topk_indices.to(torch.int32) + + def get_config_dtype_str( dtype: torch.dtype, use_int4_w4a16: Optional[bool] = False, @@ -1394,9 +1431,8 @@ def fused_experts(hidden_states: torch.Tensor, # E8M0 scale, which means we requantize the weight and input to the specific # scale. Fallen back to cutlass or triton for some cases would cause # accuracy issue. - should_use_deep_gemm = is_blackwell_deep_gemm_e8m0_used( - ) or _valid_deep_gemm(hidden_states, w1, w2) - if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm): + if (allow_deep_gemm and use_fp8_w8a8 and + (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))): assert apply_router_weight_on_input is False assert is_act_and_mul, ( "DeepGemm only supports is_act_and_mul=True for now.") @@ -1628,17 +1664,6 @@ def fused_experts_impl( block_shape=block_shape, B_bias=w1_bias) - # TODO fused kernel - def swiglu_oai(gate_up): - alpha = 1.702 - limit = 7.0 - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=limit) - up = up.clamp(min=-limit, max=limit) - glu = gate * torch.sigmoid(gate * alpha) - gated_output = (up + 1) * glu - return gated_output - # Activation function with multiplication if activation == "silu" and is_act_and_mul: torch.ops._C.silu_and_mul(intermediate_cache2, @@ -1646,13 +1671,16 @@ def fused_experts_impl( elif activation == "gelu" and is_act_and_mul: torch.ops._C.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + elif activation == "swigluoai" and 
is_act_and_mul: + # alpha = 1.702, limit = 7.0 + torch.ops._C.swigluoai_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) # Activation function without multiplication elif activation == "silu": intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N)) elif activation == "gelu": intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N)) - elif activation == "swiglu_oai": - intermediate_cache2 = swiglu_oai(intermediate_cache1.view(-1, N)) + else: raise ValueError(f"Unsupported FusedMoe activation: {activation}, " f"with is_act_and_mul={is_act_and_mul}.") @@ -1905,7 +1933,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]], ): # Check constraints. if self.use_int4_w4a16: diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 6b5284dc6c96c..312befe2c1d71 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional import torch @@ -8,7 +8,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate) -from vllm.model_executor.layers.fused_moe.utils import extract_required_args from vllm.utils import has_triton_kernels logger = init_logger(__name__) @@ -160,12 +159,16 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): num_dispatchers: int, w1_precision: "PrecisionConfig", w2_precision: "PrecisionConfig", + w1_bias: 
Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], ): super().__init__(quant_config) self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers self.w1_precision = w1_precision self.w2_precision = w2_precision + self.w1_bias = w1_bias + self.w2_bias = w2_bias @property def activation_formats( @@ -219,11 +222,7 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]], ): - w1_bias, w2_bias = (extract_required_args(extra_expert_args, - ["w1_bias", "w2_bias"])) - return triton_kernel_fused_experts( output, hidden_states, @@ -240,8 +239,8 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_map=expert_map, w1_scale=w1_scale, w2_scale=w2_scale, - w1_bias=w1_bias, - w2_bias=w2_bias, + w1_bias=self.w1_bias, + w2_bias=self.w2_bias, w1_precision=self.w1_precision, w2_precision=self.w2_precision, a1_scale=a1q_scale, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 36e75825853e6..54406a5a2d87f 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -37,7 +37,6 @@ from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx, round_up) -from vllm.utils.flashinfer import has_flashinfer if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts @@ -49,9 +48,6 @@ if current_platform.is_cuda_alike(): from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, DeepEPLLPrepareAndFinalize) - if has_flashinfer(): - from .flashinfer_cutlass_prepare_finalize import ( - FlashInferCutlassMoEPrepareAndFinalize) else: fused_experts = None # type: 
ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore @@ -80,7 +76,12 @@ class FusedMoeWeightScaleSupported(Enum): class FusedMoEMethodBase(QuantizeMethodBase): - moe: FusedMoEConfig + # TODO(bnell): also pass quant_config? + def __init__(self, moe: FusedMoEConfig): + super().__init__() + self.moe = moe + self.fused_experts: Optional[Callable] = None + self.topk_indices_dtype = None @abstractmethod def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -99,16 +100,16 @@ class FusedMoEMethodBase(QuantizeMethodBase): return False @staticmethod - def maybe_make_prepare_finalize( - moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]: + def _maybe_make_prepare_finalize( + moe: FusedMoEConfig, ) -> Optional[FusedMoEPrepareAndFinalize]: all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None - if moe.use_flashinfer_cutlass_kernels: - prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize( - quant_dtype=moe.quant_dtype, ) + assert not moe.use_flashinfer_cutlass_kernels, \ + "Must be created in modelopt.py" + if moe.use_pplx_kernels: hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( moe.max_num_tokens, @@ -188,14 +189,27 @@ class FusedMoEMethodBase(QuantizeMethodBase): return prepare_finalize - def init_prepare_finalize(self, moe: FusedMoEConfig): - self.moe = moe - prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize( - self.moe) + def maybe_make_prepare_finalize( + self, + moe: FusedMoEConfig, + ) -> Optional[FusedMoEPrepareAndFinalize]: + if moe.moe_parallel_config.use_all2all_kernels: + return FusedMoEMethodBase._maybe_make_prepare_finalize(moe) + else: + return None + + # Note: init_prepare_finalize should only be called by + # prepare_communication_buffer_for_model. 
+ def init_prepare_finalize(self): + assert self.moe is not None + prepare_finalize = self.maybe_make_prepare_finalize(self.moe) - self.topk_indices_dtype = None if prepare_finalize is not None: - logger.debug("%s", prepare_finalize.__class__.__name__) + logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__, + self, id(self)) + assert self.topk_indices_dtype is None + assert self.fused_experts is None, \ + f"Attempt to override experts for {id(self)}!" self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() experts = self.select_gemm_impl(prepare_finalize, self.moe) self.fused_experts = FusedMoEModularKernel( @@ -214,12 +228,6 @@ class FusedMoEMethodBase(QuantizeMethodBase): f"{self.__class__.__name__} must select appropriate gemm " "implementation based on the prepare_finalize") - def maybe_swap_experts_impl( - self, - moe_parallel_config: FusedMoEParallelConfig, - ): - pass - @abstractmethod def apply( self, @@ -251,10 +259,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" def __init__(self, moe: FusedMoEConfig): - super().__init__() - self.fused_experts = fused_experts # type: ignore - self.topk_indices_dtype = None - self.moe = moe + super().__init__(moe) self.has_bias = self.moe.has_bias self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() if self.rocm_aiter_moe_enabled: @@ -266,6 +271,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, + # TODO(bnell): Remove. Every layer should have an moe config object. 
moe: FusedMoEConfig, ) -> FusedMoEPermuteExpertsUnpermute: if (prepare_finalize.activation_format == @@ -352,12 +358,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): use_prepack=True, ) elif current_platform.is_cpu(): + from vllm.model_executor.layers.fused_moe import cpu_fused_moe if current_platform.get_cpu_architecture() == CpuArchEnum.X86: - from vllm.model_executor.layers.fused_moe import cpu_fused_moe - dtype = layer.w13_weight.dtype + from vllm.model_executor.layers.utils import ( + check_cpu_sgl_kernel) + dtype_w13 = layer.w13_weight.dtype + _, n_w13, k_w13 = layer.w13_weight.size() + dtype_w2 = layer.w2_weight.dtype + _, n_w2, k_w2 = layer.w2_weight.size() if (envs.VLLM_CPU_SGL_KERNEL - and torch._C._cpu._is_amx_tile_supported() - and dtype == torch.bfloat16): + and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13) + and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2)): packed_w13_weight = torch.ops._C.convert_weight_packed( layer.w13_weight) assert packed_w13_weight.size() == layer.w13_weight.size() @@ -371,7 +382,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): else: layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer) else: - raise NotImplementedError("CPU MOE only supports x86 arch.") + layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) def apply( self, @@ -474,9 +485,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map=expert_map, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) - else: - # add w1_bias/w2_bias to kwargs if they exist - kwargs = dict( + elif self.fused_experts is not None: + if self.has_bias: + raise ValueError( + "FusedMoEModularKernel does not support bias.") + return self.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -488,17 +501,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): global_num_experts=global_num_experts, expert_map=expert_map, ) - if isinstance(self.fused_experts, - 
FusedMoEModularKernel) and self.has_bias: - raise ValueError( - "FusedMoEModularKernel does not support bias.") - if self.has_bias: - kwargs.update({ - "w1_bias": getattr(layer, "w13_bias", None), - "w2_bias": getattr(layer, "w2_bias", None), - }) - - return self.fused_experts(**kwargs) + else: + assert fused_experts is not None + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + w1_bias=layer.w13_bias if self.has_bias else None, + w2_bias=layer.w2_bias if self.has_bias else None, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) def forward_cpu( self, @@ -682,6 +700,26 @@ def determine_expert_map( return (local_num_experts, expert_map) +def get_compressed_expert_map(expert_map: torch.Tensor) -> str: + """ + Compresses the expert map by removing any -1 entries. + + Args: + expert_map (torch.Tensor): A tensor of shape (global_num_experts,) + mapping from global to local index. Contains -1 for experts not + assigned to the current rank. + + Returns: + str: A string mapping from local to global index. + Using str to support hashing for logging once only. + """ + global_indices = torch.where(expert_map != -1)[0] + local_indices = expert_map[global_indices] + return ", ".join( + f"{local_index.item()}->{global_index.item()}" + for local_index, global_index in zip(local_indices, global_indices)) + + @CustomOp.register("fused_moe") class FusedMoE(CustomOp): """FusedMoE layer for MoE models. 
@@ -751,11 +789,11 @@ class FusedMoE(CustomOp): self.global_num_experts = num_experts + num_redundant_experts # we padding globally so EP buffer allocation works - if (quant_config and quant_config.get_name() == "mxfp4" - and (current_platform.is_rocm() - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16)): - hidden_size = round_up(hidden_size, 256) + if quant_config and quant_config.get_name() == "mxfp4": + from vllm.model_executor.layers.quantization.mxfp4 import ( # noqa: E501 + should_use_flashinfer_mxfp4) + if current_platform.is_rocm() or should_use_flashinfer_mxfp4(): + hidden_size = round_up(hidden_size, 256) # For smuggling this layer into the fused moe custom op compilation_config = vllm_config.compilation_config @@ -782,6 +820,12 @@ class FusedMoE(CustomOp): ep_size=self.ep_size, ep_rank=self.ep_rank, global_num_experts=self.global_num_experts) + logger.info_once( + "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" + " number of experts: %s/%s. 
Experts local to global index map:" + " %s.", self.ep_rank, self.ep_size, self.local_num_experts, + self.global_num_experts, + get_compressed_expert_map(self.expert_map)) else: self.local_num_experts, self.expert_map = (self.global_num_experts, None) @@ -868,8 +912,6 @@ class FusedMoE(CustomOp): moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) - if isinstance(self.quant_method, FusedMoEMethodBase): - self.quant_method.maybe_swap_experts_impl(self.moe_parallel_config) # Chunked all2all staging tensor self.batched_hidden_states: Optional[torch.Tensor] = None diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 6262904e4dca1..2ea6383d5ae90 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum from math import prod -from typing import Any, Optional, final +from typing import Optional, final import torch @@ -150,15 +150,23 @@ class FusedMoEPrepareAndFinalize(ABC): @abstractmethod def prepare( - self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, - topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> tuple[ + torch.Tensor, + 
Optional[torch.Tensor], + Optional[ExpertTokensMetadata], + Optional[torch.Tensor], + Optional[torch.Tensor], + ]: """ Perform any quantization (and/or) dispatching needed for this kernel. @@ -186,11 +194,15 @@ class FusedMoEPrepareAndFinalize(ABC): raise NotImplementedError @abstractmethod - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: TopKWeightAndReduce, + ) -> None: """ Perform any combine plus apply weights and perform a reduction on the fused experts output. @@ -368,7 +380,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC): workspace2: torch.Tensor, expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]], ): """ This function computes the intermediate result of a Mixture of Experts @@ -454,18 +465,27 @@ class FusedMoEModularKernel(torch.nn.Module): f"{fused_experts.activation_formats[0]}") def _do_fused_experts( - self, fused_out: Optional[torch.Tensor], a1: torch.Tensor, - a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, global_num_experts: int, local_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - expert_tokens_meta: Optional[ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor: + self, + fused_out: 
Optional[torch.Tensor], + a1: torch.Tensor, + a1q: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + local_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + expert_tokens_meta: Optional[ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ) -> torch.Tensor: _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) @@ -509,7 +529,7 @@ class FusedMoEModularKernel(torch.nn.Module): workspace2=workspace2, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=extra_expert_args) + ) return fused_out @@ -533,7 +553,6 @@ class FusedMoEModularKernel(torch.nn.Module): a2_scale: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]], ) -> torch.Tensor: _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) @@ -541,6 +560,9 @@ class FusedMoEModularKernel(torch.nn.Module): CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE num_chunks = cdiv(M, CHUNK_SIZE) + # TODO(bnell): get rid of one level here, update slice functions + # to nops on num_chunks==1 + if not self.fused_experts.supports_chunking() or num_chunks == 1: return self._do_fused_experts( fused_out=None, @@ -562,7 +584,7 @@ class FusedMoEModularKernel(torch.nn.Module): a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=extra_expert_args) + ) # Chunking required case assert num_chunks > 1 @@ -618,15 +640,6 @@ class FusedMoEModularKernel(torch.nn.Module): expert_num_tokens=c_expert_num_tokens, expert_num_tokens_cpu=c_expert_num_tokens_cpu) - 
m = None - if extra_expert_args is not None and 'm' in extra_expert_args: - m = extra_expert_args.get('m') - - if extra_expert_args is not None: - chunked_extra_expert_args = extra_expert_args - else: - chunked_extra_expert_args = {} - for chunk_idx in range(num_chunks): c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = ( slice_input_tensors(chunk_idx)) @@ -637,11 +650,6 @@ class FusedMoEModularKernel(torch.nn.Module): expert_tokens_meta, c_topk_ids, local_num_experts, expert_map) - s = chunk_idx * CHUNK_SIZE - e = min(s + CHUNK_SIZE, M) - - if m is not None: - chunked_extra_expert_args['m'] = e - s self._do_fused_experts( fused_out=slice_output_tensor(chunk_idx), a1=a1, @@ -662,7 +670,7 @@ class FusedMoEModularKernel(torch.nn.Module): a2_scale=c_a2_scale, expert_tokens_meta=c_expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=chunked_extra_expert_args) + ) return fused_out @@ -684,9 +692,6 @@ class FusedMoEModularKernel(torch.nn.Module): a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, - extra_expert_args: Optional[dict] = None, - extra_prepare_args: Optional[dict] = None, - extra_finalize_args: Optional[dict] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets @@ -719,12 +724,6 @@ class FusedMoEModularKernel(torch.nn.Module): - apply_router_weight_on_input (bool): When true, the topk weights are applied directly on the inputs. This is only applicable when topk is 1. - - extra_expert_args (Optional[dict]): Extra keyword arguments to pass to - fused_experts.apply. - - extra_prepare_args (Optional[dict]): Extra keyword arguments to pass - to prepare. - - extra_finalize_args (Optional[dict]): Extra keyword arguments to pass - to finalize. Returns: - torch.Tensor: The output tensor after applying the MoE layer. 
@@ -748,7 +747,6 @@ class FusedMoEModularKernel(torch.nn.Module): expert_map, apply_router_weight_on_input, self.fused_experts.quant_config, - extra_prepare_args, ) # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. @@ -786,12 +784,15 @@ class FusedMoEModularKernel(torch.nn.Module): a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=extra_expert_args) + ) self.prepare_finalize.finalize( - output, fused_out, topk_weights, topk_ids, + output, + fused_out, + topk_weights, + topk_ids, apply_router_weight_on_input, self.fused_experts.finalize_weight_and_reduce_impl(), - extra_finalize_args) + ) return output diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py index d35bd0098b3ca..582ae3e12c289 100644 --- a/vllm/model_executor/layers/fused_moe/moe_pallas.py +++ b/vllm/model_executor/layers/fused_moe/moe_pallas.py @@ -3,7 +3,6 @@ import torch import torch.nn.functional as F -import torch_xla.experimental.custom_kernel # noqa: F401 def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor: @@ -41,6 +40,7 @@ def fused_moe( gating_output: [*, num_experts] """ assert expert_map is None, "expert_map is not supported for pallas MoE." 
+ import torch_xla.experimental.custom_kernel # noqa: F401 orig_shape = hidden_states.shape hidden_size = hidden_states.shape[-1] num_tokens = hidden_states.shape[:-1].numel() diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index d9059f50b4459..16a155e718478 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -82,7 +82,8 @@ def moe_permute( n_local_expert: int = -1, expert_map: Optional[torch.Tensor] = None, align_block_size: Optional[int] = None, - fill_invalid_expert: int = -1 + fill_invalid_expert: int = -1, + permuted_hidden_states: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -95,14 +96,17 @@ def moe_permute( - n_expert (int): The number of expert. - n_local_expert (int): The number of expert in current EP rank. - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert + from the global expert space to the local expert space of the expert parallel shard. - align_block_size (Optional[int]): align group gemm block size for deepgemm - fill_invalid_expert(int): fill expert id in m_indices for invalid expert to workaround DeepGemm unsupported -1 in m_indices + - permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor. + If None, the output tensor will be created in this function. Returns: - permuted_hidden_states (torch.Tensor): permuted activation. - - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states + - a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states + if original scale not per-tensor scaling - expert_first_token_offset (torch.Tensor): offset of the first token of each expert for standard grouped gemm. 
if enable 'align_block_size' expert_first_token_offset will align up to 'align_block_size'. @@ -122,11 +126,16 @@ def moe_permute( 1) // align_block_size * align_block_size if n_local_expert == -1: n_local_expert = n_expert - permuted_hidden_states = torch.empty( - (permuted_row_size, n_hidden), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + if permuted_hidden_states is None: + permuted_hidden_states = torch.empty( + (permuted_row_size, n_hidden), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + assert permuted_hidden_states.size() == (permuted_row_size, n_hidden), ( + f"Expected permuted hidden states to be {(permuted_row_size, n_hidden)}" + f" but got {permuted_hidden_states.size()}") + token_expert_indices = torch.arange(0, n_token * topk, dtype=torch.int32, @@ -153,7 +162,8 @@ def moe_permute( align_block_size, permuted_hidden_states, expert_first_token_offset, inv_permuted_idx, permuted_idx, m_indices) - if a1q_scale is not None: + + if a1q_scale is not None and a1q_scale.dim() > 1: a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) // topk] return (permuted_hidden_states, a1q_scale, expert_first_token_offset, @@ -185,6 +195,7 @@ def moe_unpermute( n_hidden = permuted_hidden_states.size(-1) assert (n_hidden * permuted_hidden_states.element_size() ) % 16 == 0, "unpermue kernel need hidden dim align to 16B" + torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights, inv_permuted_idx, expert_first_token_offset, topk, out) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 46931f2dd7c78..401f37922b7bb 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing 
import Optional, Union import pplx_kernels as pplx import torch @@ -21,7 +21,7 @@ def pplx_hidden_dim_scale_bytes( max_num_tokens: int, hidden_dim: int, in_dtype: torch.dtype, - quant_dtype: Optional[torch.dtype], + quant_dtype: Union[torch.dtype, str, None], per_act_token_quant: bool, block_shape: Optional[list[int]], ): @@ -32,6 +32,7 @@ def pplx_hidden_dim_scale_bytes( # ceil_div(hidden_dim, block_size) * sizeof(float32) # For per-token: set to 4 * sizeof(float32) (x4 for alignment) if quant_dtype is not None: + assert isinstance(quant_dtype, torch.dtype) assert quant_dtype.itemsize == 1 hidden_dim_bytes = hidden_dim * quant_dtype.itemsize elem_size = torch.float32.itemsize @@ -89,12 +90,16 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return self.num_dispatchers_ def prepare( - self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, - topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -213,11 +218,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return expert_x, expert_x_scale, expert_tokens_meta, None, None - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + def finalize( + self, + output: torch.Tensor, + 
fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: assert isinstance( weight_and_reduce_impl, TopKWeightAndReduceDelegate ), ("Weight application and reduction happens in the combine kernel.") diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 696c7cdba9a7b..567a0a88fec0a 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import torch @@ -38,7 +38,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -50,32 +49,26 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): "apply_router_weight_on_input is only implemented for topk=1" a1.mul_(topk_weights.to(a1.dtype)) - if (extra_prepare_args is not None - and extra_prepare_args.get("skip_quant", True)): - # Skip quantization if explicitly requested - return a1, None, None, None, None - a1q, a1q_scale = moe_kernel_quantize_input( a1, a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) return a1q, a1q_scale, None, None, None - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, 
Any]]) -> None: - if (extra_finalize_args is not None - and extra_finalize_args.get("skip_weight_reduce", True)): - assert output.shape == fused_expert_output.shape - output.copy_(fused_expert_output) - else: - if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): - weight_and_reduce_impl = TopKWeightAndReduceContiguous() - weight_and_reduce_impl.apply( - output=output, - fused_expert_output=fused_expert_output, - topk_weights=topk_weights, - topk_ids=topk_ids, - apply_router_weight_on_input=apply_router_weight_on_input) + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceContiguous() + weight_and_reduce_impl.apply( + output=output, + fused_expert_output=fused_expert_output, + topk_weights=topk_weights, + topk_ids=topk_ids, + apply_router_weight_on_input=apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 9d0ff2e06190e..6cd81d97f0298 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import torch @@ -10,7 +10,7 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape, deep_gemm_block_shape) from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used class 
TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -107,7 +107,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. - if self.allow_deep_gemm and (is_blackwell_deep_gemm_e8m0_used() + if self.allow_deep_gemm and (is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K)): assert self.deep_gemm_expert is not None return self.deep_gemm_expert.workspace_shapes( @@ -119,21 +119,31 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): local_num_experts, expert_tokens_meta) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): 
use_deep_gemm = (self.allow_deep_gemm and (_valid_deep_gemm(hidden_states, w1, w2) - or is_blackwell_deep_gemm_e8m0_used())) + or is_deep_gemm_e8m0_used())) experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert assert experts is not None @@ -158,5 +168,4 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace2, expert_tokens_meta, apply_router_weight_on_input, - extra_expert_args, ) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 966471b5c59b4..4c3e700ad3990 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from math import prod -from typing import Any, Optional, Union +from typing import Optional, Union import torch @@ -189,7 +189,7 @@ def moe_kernel_quantize_input( return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == torch.int8: return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) - elif quant_dtype == torch.uint8: # nvfp4 + elif quant_dtype == "nvfp4": return _fp4_quantize(A, A_scale, is_sf_swizzled_layout=is_fp4_scale_swizzled) @@ -252,17 +252,3 @@ def _validate_scale_shape( assert block_shape is not None expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" - - -def extract_required_args( - extra_args: Optional[dict[str, Any]], - required_keys: list[str], -) -> tuple[Any, ...]: - if extra_args is None: - raise ValueError("`extra_args` must be provided.") - - missing_keys = [k for k in required_keys if k not in extra_args] - if missing_keys: - raise ValueError(f"Missing keys in `extra_args`: {missing_keys}") - - return tuple(extra_args[k] for k in required_keys) diff --git a/vllm/model_executor/layers/lightning_attn.py 
b/vllm/model_executor/layers/lightning_attn.py index 8ffc700ca5cde..0b87acc851208 100644 --- a/vllm/model_executor/layers/lightning_attn.py +++ b/vllm/model_executor/layers/lightning_attn.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + import torch from einops import rearrange @@ -453,7 +455,14 @@ class _attention(torch.autograd.Function): lightning_attention_ = _attention.apply -def lightning_attention(q, k, v, ed, block_size=256, kv_history=None): +def lightning_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ed: torch.Tensor, + block_size: int = 256, + kv_history: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: """ Apply lightning attention algorithm to compute attention efficiently. diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 75391c51f7754..c0fcacd1e6ee9 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -42,7 +42,6 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "GPTQMarlinLinearMethod", "Fp8LinearMethod", "MarlinLinearMethod", - "QQQLinearMethod", "GPTQMarlin24LinearMethod", "TPUInt8LinearMethod", "GPTQLinearMethod", @@ -53,6 +52,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "HQQMarlinMethod", "QuarkLinearMethod", "ModelOptNvFp4LinearMethod", + "PetitNvFp4LinearMethod", ] @@ -200,11 +200,10 @@ class UnquantizedLinearMethod(LinearMethodBase): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL: + from vllm.model_executor.layers.utils import check_cpu_sgl_kernel N, K = layer.weight.size() dtype = layer.weight.dtype - if (torch._C._cpu._is_amx_tile_supported() - and dtype == torch.bfloat16 and N % 32 == 0 - and K % 32 == 0): + if check_cpu_sgl_kernel(N, K, dtype): packed_weight = torch.ops._C.convert_weight_packed( layer.weight) assert packed_weight.size() == 
layer.weight.size() @@ -216,7 +215,8 @@ class UnquantizedLinearMethod(LinearMethodBase): else: logger.warning( "CPU SGL kernels require Intel AMX support," - " bfloat16 weight, IC and OC are divisible by 32.") + " bf16/fp16/int8 weight, IC and OC are divisible by " + "32 and 16.") layer.use_cpu_sgl = False def apply(self, @@ -233,10 +233,10 @@ class LinearBase(CustomOp): Args: input_size: input dimension of the linear layer. output_size: output dimension of the linear layer. - bias: If true, add bias. skip_bias_add: If true, skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. + prefix: Prefix for parameter names. return_bias: If true, return bias together with outputs in forward pass. """ @@ -378,13 +378,14 @@ class MergedReplicatedLinear(ReplicatedLinear): Args: input_size: input dimension of the linear layer. - output_size: output dimension of the linear layer. + output_sizes: list of output dimensions of the linear layer. bias: If true, add bias. skip_bias_add: If true, skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. """ def __init__( @@ -437,7 +438,7 @@ class MergedReplicatedLinear(ReplicatedLinear): shard_offset = sum(self.output_sizes[:loaded_shard_id]) shard_size = self.output_sizes[loaded_shard_id] - param[shard_offset:shard_offset + shard_size] = loaded_weight + param.data[shard_offset:shard_offset + shard_size] = loaded_weight @CustomOp.register("column_parallel_linear") @@ -692,8 +693,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): param_data = param.data output_dim = getattr(param, "output_dim", None) - # Special case for AQLM codebooks. 
- is_metadata = getattr(param, "is_metadata", False) # Special case for per-tensor scale to load scalar into fused array. needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) @@ -781,13 +780,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): if not is_sharded_weight: loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) - # Special case for AQLM codebooks. - elif is_metadata: - # metadata indicates fixed size concatenated along dim 0 - shard_size = loaded_weight.shape[0] - shard_offset = loaded_shard_id * shard_size - param_data = param_data.narrow(0, shard_offset, shard_size) - # Special case for per-tensor scales in fused case. elif needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( @@ -1081,8 +1073,6 @@ class QKVParallelLinear(ColumnParallelLinear): param_data = param.data output_dim = getattr(param, "output_dim", None) - # Special case for AQLM codebooks. - is_metadata = getattr(param, "is_metadata", False) # Special case for per-tensor scales in fused case. needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) @@ -1204,13 +1194,6 @@ class QKVParallelLinear(ColumnParallelLinear): loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) - # Special case for for AQLM codebooks. - elif is_metadata: - # metadata indicates fixed size concatenated along dim 0 - shard_size = loaded_weight.shape[0] - shard_index = ["q", "k", "v"].index(loaded_shard_id) - param_data = param_data.narrow(0, shard_index * shard_size, - shard_size) # Special case for per-tensor scales in fused case. 
elif needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( @@ -1396,7 +1379,7 @@ class RowParallelLinear(LinearBase): return output, output_bias def extra_repr(self) -> str: - s = f"input_features={self.input_size_per_partition}" + s = f"in_features={self.input_size_per_partition}" s += f", output_features={self.output_size}" s += f", bias={self.bias is not None}" s += f", tp_size={self.tp_size}" diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index daebe46f6f771..a524e13405807 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -1,12 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import ABC, abstractmethod +from abc import abstractmethod from collections.abc import Iterable +from typing import TYPE_CHECKING import torch +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -class MambaBase(ABC): +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + + +class MambaBase(AttentionLayerBase): """ Base class for Mamba-like layers which support the v1 engine. Inherit from this class if you implement a custom layer. 
@@ -32,3 +38,8 @@ class MambaBase(ABC): @abstractmethod def mamba_type(self) -> str: pass + + @abstractmethod + def get_attn_backend(self) -> type["AttentionBackend"]: + """Get the attention backend class for this Mamba layer.""" + pass diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 3c7322260df43..e704bfd451bce 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import NamedTuple, Optional +from typing import TYPE_CHECKING, NamedTuple, Optional + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import torch from torch import nn @@ -27,6 +30,8 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata @@ -183,22 +188,26 @@ class MambaMixer(MambaBase, CustomOp): def forward(self, hidden_states: torch.Tensor, + output: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None): if not envs.VLLM_USE_V1: - return CustomOp.forward(self, hidden_states, mamba_cache_params) + CustomOp.forward(self, hidden_states, output, mamba_cache_params) else: - return self.forward_cuda( + torch.ops.vllm.mamba_mixer( hidden_states, - mamba_cache_params, + output, + self.prefix, ) def forward_native(self, hidden_states: torch.Tensor, + output: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None): pass def forward_cuda(self, hidden_states: torch.Tensor, + output: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None): 
""" Run the Mamba-1 SSM pipeline. @@ -237,6 +246,7 @@ class MambaMixer(MambaBase, CustomOp): conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] has_initial_states = mamba1_metadata.has_initial_states + num_padded_decodes = mamba1_metadata.num_padded_decodes else: assert isinstance(attn_metadata, AttentionMetadata) assert mamba_cache_params is not None @@ -248,6 +258,7 @@ class MambaMixer(MambaBase, CustomOp): has_initial_states = None if context_lens_tensor is not None: has_initial_states = context_lens_tensor > 0 + num_padded_decodes = attn_metadata.num_decode_tokens # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -267,6 +278,7 @@ class MambaMixer(MambaBase, CustomOp): num_decodes = attn_metadata.num_decode_tokens # token count (=request) has_prefill = num_prefill_tokens > 0 has_decode = num_decode_tokens > 0 + num_actual_tokens = num_prefill_tokens + num_decode_tokens prefill_decode_split = split_batch_to_prefill_and_decode( hidden_states_BC, @@ -278,6 +290,7 @@ class MambaMixer(MambaBase, CustomOp): num_decode_tokens, num_prefills, num_decodes, + num_padded_decodes, ) hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d @@ -371,7 +384,7 @@ class MambaMixer(MambaBase, CustomOp): else: out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0] - return out + output[:num_actual_tokens] = out def get_state_dtype(self) -> tuple[torch.dtype]: assert self.model_config is not None @@ -394,6 +407,11 @@ class MambaMixer(MambaBase, CustomOp): def mamba_type(self) -> str: return "mamba1" + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.mamba1_attn import ( + Mamba1AttentionBackend) + return Mamba1AttentionBackend + def _time_proj_bias(self) -> Optional[torch.Tensor]: if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None: return self.dt_proj.bias.float() @@ -421,18 
+439,27 @@ def split_batch_to_prefill_and_decode( num_decode_tokens: int, num_prefills: int, num_decodes: int, + num_padded_decodes: int, ) -> PrefillDecodeSplit: + num_actual_tokens = num_prefill_tokens + num_padded_decodes + if envs.VLLM_USE_V1: # In v1, decode tokens come first, then prefill tokens. hidden_states_BC_d, hidden_states_BC_p = torch.split( - hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1) - gate_d, gate_p = torch.split(gate, - [num_decode_tokens, num_prefill_tokens], + hidden_states_BC[..., :num_actual_tokens], + [num_padded_decodes, num_prefill_tokens], + dim=-1) + gate_d, gate_p = torch.split(gate[..., :num_actual_tokens], + [num_padded_decodes, num_prefill_tokens], dim=-1) + + # num_padded_decodes accounts for CUDA graph padding when applicable state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor, [num_decodes, num_prefills], dim=0) + state_indices_tensor[:num_padded_decodes + num_prefills], + [num_padded_decodes, num_prefills], + dim=0) query_start_loc_p = (query_start_loc[-num_prefills - 1:] - - num_decodes if num_prefills > 0 else None) + num_padded_decodes if num_prefills > 0 else None) has_initial_states_p = has_initial_states[-num_prefills:] if ( has_initial_states is not None and num_prefills > 0) else None else: @@ -459,3 +486,32 @@ def split_batch_to_prefill_and_decode( query_start_loc_p=query_start_loc_p, has_initial_states_p=has_initial_states_p, ) + + +def mamba_mixer( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self.forward_cuda(hidden_states=hidden_states, + output=output, + mamba_cache_params=None) + + +def mamba_mixer_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="mamba_mixer", + op_func=mamba_mixer, + mutates_args=["output"], + 
fake_impl=mamba_mixer_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 743e520ec8ee1..bb3fdd38dbef3 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import torch from torch import nn @@ -758,6 +761,11 @@ class MambaMixer2(MambaBase, CustomOp): def mamba_type(self) -> str: return "mamba2" + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.mamba2_attn import ( + Mamba2AttentionBackend) + return Mamba2AttentionBackend + def mamba_mixer2( hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 66674d1a6f251..280a9e45e662e 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -54,6 +54,16 @@ class MambaStateDtypeCalculator: return (conv_state_dtype, temporal_state_dtype) + @classmethod + def short_conv_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, ...]: + conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, + model_dtype) + return (conv_state_dtype, ) + class MambaStateShapeCalculator: @@ -122,6 +132,20 @@ class MambaStateShapeCalculator: tp_world_size), head_dim, state_size) return conv_state_shape, temporal_state_shape + @classmethod + def short_conv_state_shape( + cls, + tp_world_size: int, + intermediate_size: int, + conv_kernel: int, + use_v1: bool = True, + ) -> tuple[tuple[int, int]]: + conv_dim = 
divide(intermediate_size, tp_world_size) + conv_state_shape = (conv_kernel - 1, conv_dim) + if not use_v1: + conv_state_shape = conv_state_shape[1], conv_state_shape[0] + return (conv_state_shape, ) + @classmethod def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int): """Compute the increase in group numbers to account for diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py new file mode 100644 index 0000000000000..335191a5c82c1 --- /dev/null +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -0,0 +1,270 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +import torch + +from vllm import envs +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op +from vllm.v1.attention.backends.short_conv_attn import ( + ShortConvAttentionMetadata) + + +@CustomOp.register("short_conv") +class ShortConv(MambaBase, CustomOp): + + def __init__(self, + config, + dim: int, + 
layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.conv_dim = dim + self.L_cache = config.conv_L_cache + self.bias = config.conv_bias + + self.conv = ColumnParallelLinear( + input_size=self.L_cache, + output_size=dim, + bias=self.bias, + prefix=f"{prefix}.conv1d", + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv.weight.data = self.conv.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear( + input_size=dim, + output_sizes=[dim] * 3, + bias=self.bias, + prefix=f"{prefix}.in_proj", + ) + self.out_proj = RowParallelLinear( + input_size=dim, + output_size=dim, + bias=self.bias, + prefix=f"{prefix}.out_proj", + ) + + assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1") + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The outer list is for v0 PP virtual engine. Though this code path + # only runs for v1, we have to do this to unify with the interface + # of Attention + v0 PP. 
+ self.kv_cache = [(torch.tensor([]), )] + + self.model_config = model_config + self.cache_config = cache_config + self.prefix = prefix + + def forward_native( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + conv_metadata: ShortConvAttentionMetadata, + ): + return + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + conv_metadata: ShortConvAttentionMetadata, + ): + torch.ops.vllm.short_conv( + hidden_states, + output, + self.prefix, + ) + + def forward_cuda( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + conv_metadata: ShortConvAttentionMetadata, + ): + forward_context = get_forward_context() + # ShortConvAttentionMetadata contains metadata necessary for the + # short_conv triton kernels to operate in continuous batching and in + # chunked prefill modes; they are computed at top-level model forward + # since they stay the same and reused for all mamba layers in the same + # iteration. + attn_metadata: AttentionMetadata = forward_context.attn_metadata + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + conv_metadata = attn_metadata + assert isinstance(attn_metadata, ShortConvAttentionMetadata) + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + state_indices_tensor = attn_metadata.state_indices_tensor + has_initial_states_p = attn_metadata.has_initial_states + + BCx, _ = self.in_proj(hidden_states) + + B, C, x = BCx.chunk(3, dim=-1) + + conv_weights = self.conv.weight.view(self.conv.weight.size(0), + self.conv.weight.size(2)) + + if attn_metadata is None: + # V1 profile run + Bx = (B * x).contiguous() + hidden_states = C * Bx + contextualized_states, _ = self.out_proj(hidden_states) + return contextualized_states + + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) + num_prefill_tokens = 
attn_metadata.num_prefill_tokens # token count + has_prefill = num_prefills > 0 + has_decode = num_decodes > 0 + num_actual_tokens = num_decodes + num_prefill_tokens + + # NOTE: V1 puts decode before prefill + # Separate prefill and decode by splitting varlen input + # Split along token dimension + B_d, B_p = torch.split( + B[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + C_d, C_p = torch.split( + C[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + x_d, x_p = torch.split( + x[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, + [num_decodes, num_prefills], + dim=0, + ) + query_start_loc_p = ( + attn_metadata.query_start_loc[-num_prefills - 1:] - + num_decodes if has_prefill else None) + + conv_output_list = [] + + if has_prefill: + Bx_p = (B_p * x_p).transpose(0, 1) + if conv_metadata.cu_seqlen is None: + conv_metadata = update_metadata(Bx_p, query_start_loc_p, + conv_metadata) + Bx = causal_conv1d_fn(Bx_p, + conv_weights, + self.conv.bias, + activation=None, + conv_states=conv_state, + has_initial_state=has_initial_states_p, + cache_indices=state_indices_tensor_p, + metadata=conv_metadata, + query_start_loc=query_start_loc_p).transpose( + 0, 1)[:num_prefill_tokens] + + y = C_p * Bx + conv_output_list.append(y) + + if has_decode: + Bx_d = (B_d * x_d).contiguous() + Bx = causal_conv1d_update( + Bx_d, + conv_state, + conv_weights, + self.conv.bias, + activation=None, + conv_state_indices=state_indices_tensor_d) + y = C_d * Bx + conv_output_list.insert(0, y) + + # Merge prefill and decode outputs before passing to gated MLP + hidden_states = torch.vstack(conv_output_list) + + # Final linear projection + output[:num_actual_tokens], _ = self.out_proj(hidden_states) + + def get_state_dtype(self) -> tuple[torch.dtype, ...]: + assert self.model_config is not None + assert 
self.cache_config is not None + return MambaStateDtypeCalculator.short_conv_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + ) + + def get_state_shape(self) -> tuple[tuple[int, ...]]: + return MambaStateShapeCalculator.short_conv_state_shape( + tp_world_size=get_tensor_model_parallel_world_size(), + intermediate_size=self.conv_dim, + conv_kernel=self.L_cache, + ) + + @property + def mamba_type(self) -> str: + return "short_conv" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.short_conv_attn import ( + ShortConvAttentionBackend) + return ShortConvAttentionBackend + + +def short_conv( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self.forward_cuda(hidden_states=hidden_states, + output=output, + conv_metadata=None) + + +def short_conv_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="short_conv", + op_func=short_conv, + mutates_args=["output"], + fake_impl=short_conv_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index e2162e5cbf956..eebf7b2508dbc 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -5,7 +5,7 @@ from collections.abc import Mapping, Set from dataclasses import dataclass from enum import IntEnum from itertools import groupby -from typing import Callable, Optional, TypeVar, Union +from typing import Callable, Optional, TypeVar, Union, cast import torch import torch.nn as nn @@ -19,7 +19,8 @@ from vllm.model_executor.pooling_metadata import PoolingTensors from vllm.pooling_params import PoolingParams from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.tasks import 
PoolingTask -from vllm.utils import resolve_obj_by_qualname +from vllm.utils import current_stream, resolve_obj_by_qualname +from vllm.v1.pool.metadata import PoolingCursor from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] @@ -172,6 +173,15 @@ def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: def get_classification_activation_function(config: PretrainedConfig): + # Implement alignment with transformers ForSequenceClassificationLoss + # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92 + problem_type = getattr(config, "problem_type", "") + if problem_type == "regression": + return PoolerIdentity() + if problem_type == "single_label_classification": + return PoolerClassify() + if problem_type == "multi_label_classification": + return PoolerMultiLabelClassify() return PoolerClassify() @@ -196,6 +206,13 @@ def get_cross_encoder_activation_function(config: PretrainedConfig): def build_output( all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput: + # Pooling models D2H & synchronize occurs here + if isinstance(all_data, list): + all_data = [d.to("cpu", non_blocking=True) for d in all_data] + else: + all_data = all_data.to("cpu", non_blocking=True) + current_stream().synchronize() + all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data] return PoolerOutput(outputs=all_outputs) @@ -222,40 +239,21 @@ class PoolingMethod(nn.Module, ABC): def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: return PoolingParamsUpdate() - @abstractmethod - def forward_one( - self, - hidden_states: torch.Tensor, - prompt_len: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Note: - `prompt_len=None` means `prompt_len=len(hidden_states)`. 
- """ - raise NotImplementedError - @abstractmethod def forward_all( self, hidden_states: torch.Tensor, - prompt_lens: torch.Tensor, + pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: raise NotImplementedError def forward( self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], + hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> Union[list[torch.Tensor], torch.Tensor]: - prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) - - if isinstance(hidden_states, list): - return [ - self.forward_one(h, prompt_len) - for h, prompt_len in zip(hidden_states, prompt_lens) - ] - - return self.forward_all(hidden_states, prompt_lens) + pooling_cursor = pooling_metadata.pooling_cursor + return self.forward_all(hidden_states, pooling_cursor) class CLSPool(PoolingMethod): @@ -263,24 +261,15 @@ class CLSPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode", "embed", "classify", "score"} - def forward_one( - self, - hidden_states: torch.Tensor, - prompt_len: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - assert prompt_len is None or prompt_len == hidden_states.shape[0], \ - "partial prefill not supported with CLS pooling" - - return hidden_states[0] - def forward_all( self, hidden_states: torch.Tensor, - prompt_lens: torch.Tensor, + pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: - first_token_flat_indices = torch.zeros_like(prompt_lens) - first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1] - return hidden_states[first_token_flat_indices] + assert not pooling_cursor.is_partial_prefill(), \ + "partial prefill not supported with CLS pooling" + + return hidden_states[pooling_cursor.first_token_indices_gpu] class LastPool(PoolingMethod): @@ -288,20 +277,12 @@ class LastPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode", "embed", "classify", "score"} - def forward_one( - self, - hidden_states: 
torch.Tensor, - prompt_len: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return hidden_states[-1] - def forward_all( self, hidden_states: torch.Tensor, - prompt_lens: torch.Tensor, + pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: - last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 - return hidden_states[last_token_flat_indices] + return hidden_states[pooling_cursor.last_token_indices_gpu] class AllPool(PoolingMethod): @@ -309,22 +290,19 @@ class AllPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode"} - def forward_one( - self, - hidden_states: torch.Tensor, - prompt_len: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - assert prompt_len is None or prompt_len == hidden_states.shape[0], \ - "partial prefill not supported with ALL pooling" - - return hidden_states - def forward_all( self, hidden_states: torch.Tensor, - prompt_lens: torch.Tensor, + pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: - return list(hidden_states.split_with_sizes(prompt_lens.tolist())) + + assert not pooling_cursor.is_partial_prefill(), \ + "partial prefill not supported with ALL pooling" + + hidden_states_lst = list( + hidden_states.split( + pooling_cursor.num_scheduled_tokens_cpu.tolist())) + return [hidden_states_lst[i] for i in pooling_cursor.index] class MeanPool(PoolingMethod): @@ -332,31 +310,25 @@ class MeanPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode", "embed", "classify", "score"} - def forward_one( - self, - hidden_states: torch.Tensor, - prompt_len: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - assert prompt_len is None or prompt_len == hidden_states.shape[0], \ - "partial prefill not supported with MEAN pooling" - - return hidden_states.mean(dim=0, dtype=torch.float32) - def forward_all( self, hidden_states: torch.Tensor, - prompt_lens: torch.Tensor, + pooling_cursor: PoolingCursor, ) -> 
Union[list[torch.Tensor], torch.Tensor]: + + assert not pooling_cursor.is_partial_prefill(), \ + "partial prefill not supported with MEAN pooling" + + prompt_lens = pooling_cursor.prompt_lens_cpu.to(hidden_states.device, + non_blocking=True) + # Use float32 for torch.cumsum in MeanPool, # otherwise precision will be lost significantly. cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32) - start_indices = torch.cat([ - torch.tensor([0], device=hidden_states.device), - torch.cumsum(prompt_lens[:-1], dim=0) - ]) - end_indices = torch.cumsum(prompt_lens, dim=0) - return (cumsum[end_indices - 1] - cumsum[start_indices] + + start_indices = pooling_cursor.first_token_indices_gpu + end_indices = pooling_cursor.last_token_indices_gpu + return (cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices]) / prompt_lens.unsqueeze(1) @@ -409,6 +381,12 @@ class PoolerNormalize(PoolerActivation): return x.to(pooled_data.dtype) +class PoolerMultiLabelClassify(PoolerActivation): + + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) + + class PoolerClassify(PoolerActivation): def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: @@ -457,11 +435,37 @@ class EmbeddingPoolerHead(PoolerHead): def __init__(self) -> None: super().__init__(activation=PoolerNormalize()) + # Load ST projector if available + from vllm.config import get_current_vllm_config + from vllm.model_executor.models.adapters import _load_st_projector + + vllm_config = get_current_vllm_config() + self.projector = _load_st_projector( + vllm_config.model_config) if vllm_config else None + def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): + # Apply ST projector + if self.projector is not None: + projector = cast(nn.Module, self.projector) + + def _proj(x: torch.Tensor) -> torch.Tensor: + orig_dtype = x.dtype + y = projector(x.to(torch.float32)) + 
return y.to(orig_dtype) + + if isinstance(pooled_data, torch.Tensor): + pooled_data = _proj(pooled_data) + else: + pooled_data = [_proj(t) for t in pooled_data] + pooling_params = get_pooling_params(pooling_metadata) + if isinstance(pooled_data, list): + pooled_data = torch.stack(pooled_data) + # pooled_data shape: [batchsize, embedding_dimension] + # for matryoshka representation dimensions_list = [ pooling_param.dimensions for pooling_param in pooling_params @@ -652,6 +656,10 @@ class ClassifierPooler(Pooler): ) -> PoolerOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) + if isinstance(pooled_data, list): + pooled_data = torch.stack(pooled_data) + # pooled_data shape: [batchsize, hidden_size] + if self.classifier is not None: # apply classifier once on the full batch if possible if isinstance(pooled_data, torch.Tensor): @@ -702,12 +710,6 @@ class DispatchPooler(Pooler): ) -> PoolerOutput: poolers_by_task = self.poolers_by_task - if isinstance(hidden_states, list): - hidden_states_lst = hidden_states - else: - prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) - hidden_states_lst = list(hidden_states.split(prompt_lens.tolist())) - outputs = list[PoolingSequenceGroupOutput]() offset = 0 for task, group in groupby(get_tasks(pooling_metadata)): @@ -718,7 +720,7 @@ class DispatchPooler(Pooler): num_items = len(list(group)) group_output: PoolerOutput = pooler( - hidden_states_lst[offset:offset + num_items], + hidden_states, pooling_metadata[offset:offset + num_items], ) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 8d63027e1863f..d73fcf368f261 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -7,7 +7,6 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) QuantizationMethods = Literal[ - "aqlm", "awq", "deepspeedfp", "tpu_int8", @@ -16,7 +15,6 @@ 
QuantizationMethods = Literal[ "fbgemm_fp8", "modelopt", "modelopt_fp4", - "marlin", "bitblas", "gguf", "gptq_marlin_24", @@ -26,7 +24,6 @@ QuantizationMethods = Literal[ "gptq", "compressed-tensors", "bitsandbytes", - "qqq", "hqq", "experts_int8", "neuron_quant", @@ -38,6 +35,7 @@ QuantizationMethods = Literal[ "rtn", "inc", "mxfp4", + "petit_nvfp4", ] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -88,7 +86,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: # lazy import to avoid triggering `torch.compile` too early from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig - from .aqlm import AQLMConfig from .auto_round import AutoRoundConfig from .awq import AWQConfig from .awq_marlin import AWQMarlinConfig @@ -108,19 +105,17 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .hqq_marlin import HQQMarlinConfig from .inc import INCConfig from .ipex_quant import IPEXConfig - from .marlin import MarlinConfig from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config from .moe_wna16 import MoeWNA16Config from .mxfp4 import Mxfp4Config from .neuron_quant import NeuronQuantConfig + from .petit import PetitNvFp4Config from .ptpc_fp8 import PTPCFp8Config - from .qqq import QQQConfig from .rtn import RTNConfig from .torchao import TorchAOConfig from .tpu_int8 import Int8TpuConfig method_to_config: dict[str, type[QuantizationConfig]] = { - "aqlm": AQLMConfig, "awq": AWQConfig, "deepspeedfp": DeepSpeedFPConfig, "tpu_int8": Int8TpuConfig, @@ -128,7 +123,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "fbgemm_fp8": FBGEMMFp8Config, "modelopt": ModelOptFp8Config, "modelopt_fp4": ModelOptNvFp4Config, - "marlin": MarlinConfig, "bitblas": BitBLASConfig, "gguf": GGUFConfig, "gptq_marlin_24": GPTQMarlin24Config, @@ -139,7 +133,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "compressed-tensors": 
CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, "ptpc_fp8": PTPCFp8Config, - "qqq": QQQConfig, "hqq": HQQMarlinConfig, "experts_int8": ExpertsInt8Config, "neuron_quant": NeuronQuantConfig, @@ -151,6 +144,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "rtn": RTNConfig, "inc": INCConfig, "mxfp4": Mxfp4Config, + "petit_nvfp4": PetitNvFp4Config, } # Update the `method_to_config` with customized quantization methods. method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py deleted file mode 100644 index 2ea8c5dc51132..0000000000000 --- a/vllm/model_executor/layers/quantization/aqlm.py +++ /dev/null @@ -1,376 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Supports AQLM compression, see https://github.com/Vahe1994/AQLM -# and https://arxiv.org/pdf/2401.06118.pdf - -import math -from typing import Any, Optional - -import torch -import torch.nn.functional as F -from torch.nn.parameter import Parameter - -from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.utils import set_weight_attrs - - -def get_int_dtype(nbits: int) -> torch.dtype: - if nbits <= 8: - return torch.int8 - if nbits <= 16: - return torch.int16 - if nbits <= 32: - return torch.int32 - if nbits <= 64: - return torch.int64 - raise ValueError(f"No dtype available for {nbits}-bit codebooks") - - -@torch.inference_mode() -def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor: - return data.to(torch.int64) % (2**nbits) - - -def dequantize_weight(codes: torch.Tensor, - codebooks: torch.Tensor, - scales: 
Optional[torch.Tensor] = None) -> torch.Tensor: - """ - Decode float weights from quantization codes. Differentiable. - :param codes: tensor of integer quantization codes, shape - [*dims, num_out_groups, num_in_groups, num_codebooks] - :param codebooks: tensor of vectors for each quantization code, - [num_codebooks, codebook_size, out_group_size, in_group_size] - :param scales: weight will be multiplied by this factor, must be - broadcastble with - [*dims, out_groups, num_in_groups, out_group_size, in_group_size] - :return: reconstructed weight tensor of shape - [*dims, num_in_groups*group_size] - """ - num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:] - num_codebooks, codebook_size, out_group_size, in_group_size = \ - codebooks.shape - out_features = num_out_groups * out_group_size - in_features = num_in_groups * in_group_size - codebook_offsets = torch.arange( - 0, num_codebooks * codebook_size, codebook_size, - device=codes.device) # shape: [num_codebooks] - reconstructed_weight_flat = F.embedding_bag( - codes.flatten(0, -2) + codebook_offsets, - codebooks.flatten(0, 1).flatten(-2, -1), - mode="sum" - ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size - # * in_group_size] - - reconstructed_weight_groupwise = reconstructed_weight_flat.view( - list(codes.shape[:-3]) + - [num_out_groups, num_in_groups, out_group_size, in_group_size]) - if scales is not None: - reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul( - scales) - return reconstructed_weight_groupwise.swapaxes( - -3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features]) - - -def dequantize_gemm( - input: torch.Tensor, # [..., in_features] - codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] - codebooks: torch. 
- Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] - bias: Optional[torch.Tensor], -) -> torch.Tensor: - dequantized_weight = dequantize_weight( - unpack_int_data(codes, codebooks.shape[1].bit_length() - 1), - codebooks, - scales, - ) - return F.linear(input, dequantized_weight, bias) - - -# Generic dequantization, slow but flexible. -def generic_dequantize_gemm( - input: torch.Tensor, # [..., in_features] - codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] - codebooks: torch. - Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] - output_partition_sizes: list[int], - bias: Optional[torch.Tensor], -) -> torch.Tensor: - output_shape = input.shape[:-1] + (scales.shape[0], ) - output = torch.empty(output_shape, dtype=input.dtype, device=input.device) - num_outputs = len(output_partition_sizes) - - # break the inputs and codebooks apart then combine the outputs. - # Surprisingly (to me) this is faster than doing 3 de-quants and 1 big - # multiply at the end. 
- num_codebooks = codebooks.shape[0] // num_outputs - assert (scales.shape[0] == codes.shape[0]) - assert (sum(output_partition_sizes) == scales.shape[0]) - output_offset = 0 - codebooks_offset = 0 - for output_size in output_partition_sizes: - shard_output = dequantize_gemm( - input, codes.narrow(0, output_offset, output_size), - codebooks.narrow(0, codebooks_offset, num_codebooks), - scales.narrow(0, output_offset, output_size), None - if bias is None else bias.narrow(0, output_offset, output_size)) - - output_slice = output.narrow(-1, output_offset, output_size) - assert (output_slice.shape == shard_output.shape) - output_slice.copy_(shard_output) - output_offset += output_size - codebooks_offset += num_codebooks - return output - - -# Optimized dequnantize/decompression kernels, supports 1x16 and 2x8 -# at 6 and 9 times faster than the generic version above, respectively. -def optimized_dequantize_gemm( - input: torch.Tensor, # [..., in_features] - codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] - codebooks: torch. - Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] - output_partition_sizes: list[int], - bias: Optional[torch.Tensor], -) -> torch.Tensor: - weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) - - if bias is None: - # scaling the output is fastest, so we do that when possible. - output = F.linear(input, weights, bias) - orig_shape = output.shape - flattened_output = output.view(-1, output.size(-1)) - f_scales = scales.view(-1, scales.shape[0]) - b_scales = f_scales.expand(flattened_output.shape[0], -1) - flattened_output *= b_scales - return output.view(orig_shape) - else: - b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( - -1, weights.shape[1]) - weights *= b_scales - return F.linear(input, weights, bias) - - -class AQLMConfig(QuantizationConfig): - """Config class for AQLM. 
- - Reference: https://github.com/Vahe1994/AQLM - """ - - def __init__( - self, - in_group_size: int, - nbits_per_codebook: int, - num_codebooks: int, - out_group_size: int, - ) -> None: - super().__init__() - self.in_group_size = in_group_size - self.nbits_per_codebook = nbits_per_codebook - self.num_codebooks = num_codebooks - self.out_group_size = out_group_size - - # out_group_size > 1 is untested, and probably won't work as-is. - assert (self.out_group_size == 1) - self.pack_factor = (self.in_group_size * self.out_group_size) - - def __repr__(self) -> str: - return (f"AQLMConfig(in_group_size={self.in_group_size}, " - f"nbits_per_codebook={self.nbits_per_codebook}, " - f"num_codebooks={self.num_codebooks}, " - f"out_group_size={self.out_group_size})") - - @classmethod - def get_name(cls) -> QuantizationMethods: - return "aqlm" - - @classmethod - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.half] - - @classmethod - def get_min_capability(cls) -> int: - return 60 - - @classmethod - def get_config_filenames(cls) -> list[str]: - return [] # no extra configs. - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "AQLMConfig": - in_group_size = cls.get_from_keys(config, ["in_group_size"]) - nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"]) - num_code_books = cls.get_from_keys(config, ["num_codebooks"]) - out_group_size = cls.get_from_keys(config, ["out_group_size"]) - return cls(in_group_size, nbits_per_codebook, num_code_books, - out_group_size) - - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["AQLMLinearMethod"]: - if isinstance(layer, LinearBase): - return AQLMLinearMethod(self) - return None - - -class AQLMLinearMethod(LinearMethodBase): - """Linear method for AQLM. - - Args: - quant_config: The AQLM quantization config. 
- """ - - def __init__(self, quant_config: AQLMConfig): - self.quant_config = quant_config - - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - del output_size # Unused. - del input_size # Unused. - - if params_dtype != torch.half: - raise ValueError("Only half is currently supported by aqlm") - if input_size_per_partition % self.quant_config.in_group_size != 0: - raise ValueError( - "The input size is not aligned with the quantized " - "weight shape. This can be caused by too large " - "tensor parallel size.") - - output_size_per_partition = sum(output_partition_sizes) - if output_size_per_partition % self.quant_config.out_group_size != 0: - raise ValueError( - "The output size is not aligned with the quantized " - "weight shape. This can be caused by too large " - "tensor parallel size.") - - codes = Parameter( - torch.empty( - # There could actually be two pack factors, one along input and - # one along output, but we don't currently support - # out_group_size, and only the one along output needs to be - # marked with "packed_dim" in order for QKVLinear to work. 
- output_size_per_partition, - input_size_per_partition // self.quant_config.pack_factor, - self.quant_config.num_codebooks, - dtype=get_int_dtype(self.quant_config.nbits_per_codebook), - ), - requires_grad=False, - ) - - set_weight_attrs( - codes, - { - "input_dim": 1, - "output_dim": 0, - "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, - }, - ) - - codebooks = Parameter( - torch.empty( - self.quant_config.num_codebooks * len(output_partition_sizes), - 2**self.quant_config.nbits_per_codebook, - self.quant_config.out_group_size, - self.quant_config.in_group_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - set_weight_attrs( - codebooks, - { - # metadata indicates fixed size concatenated along dim 0 - "is_metadata": True, - "output_partition_sizes": output_partition_sizes - }, - ) - - scales = Parameter( - torch.empty( - ( - output_size_per_partition // - self.quant_config.out_group_size, - 1, - 1, - 1, - ), - dtype=params_dtype, - ), - requires_grad=False, - ) - set_weight_attrs( - scales, - { - "output_dim": 0, - "packed_dim": 0, - "pack_factor": self.quant_config.out_group_size - }, - ) - - layer.register_parameter("codes", codes) - set_weight_attrs(codes, extra_weight_attrs) - layer.register_parameter("codebooks", codebooks) - set_weight_attrs(codebooks, extra_weight_attrs) - layer.register_parameter("scales", scales) - set_weight_attrs(scales, extra_weight_attrs) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - codebooks = layer.codebooks - codes = layer.codes - scales = layer.scales - output_partition_sizes = getattr(codebooks, "output_partition_sizes", - []) - - nbooks = codes.shape[2] - ingroups = codebooks.shape[3] - outgroups = codebooks.shape[2] - bits = codebooks.shape[1] - - # We support these formats with dedicated gemm and decompression - # kernels. 
- if ingroups == 8 and outgroups == 1 and ( - (bits == 256 and nbooks == 2) or (bits == 65536 and nbooks == 1)): - - # thresholds determined by timings on an A6000, one GPU - use_gemv = math.prod(x.shape[:-1]) <= 6 - - return ops.aqlm_gemm( - x, - codes, - codebooks, - scales, - output_partition_sizes, - bias, - ) if use_gemv else optimized_dequantize_gemm( - x, - codes, - codebooks, - scales, - output_partition_sizes, - bias, - ) - - # fall back all unoptimized formats - return generic_dequantize_gemm( - x, - codes, - codebooks, - scales, - output_partition_sizes, - bias, - ) diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py index a9e967e608e96..fb285413ba9ef 100644 --- a/vllm/model_executor/layers/quantization/auto_round.py +++ b/vllm/model_executor/layers/quantization/auto_round.py @@ -241,7 +241,7 @@ class AutoRoundConfig(QuantizationConfig): if isinstance(layer, FusedMoE): if use_marlin: - return AWQMoEMethod(quant_args_marlin) + return AWQMoEMethod(quant_args_marlin, layer.moe) from vllm.model_executor.layers.quantization.moe_wna16 import ( MoeWNA16Config) @@ -339,7 +339,7 @@ class AutoRoundConfig(QuantizationConfig): } return MoeWNA16Config.from_config(config).get_quant_method( layer, prefix) - return GPTQMarlinMoEMethod(quant_args_marlin) + return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe) if isinstance(layer, (LinearBase, ParallelLMHead)): if use_marlin: diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index fe42e26a17061..af602eb9aca38 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -113,7 +113,7 @@ class AWQConfig(QuantizationConfig): } awq_marlin_config = AWQMarlinConfig.from_config( marlin_compatible_config_dict) - return AWQMoEMethod(awq_marlin_config) + return AWQMoEMethod(awq_marlin_config, layer.moe_config) return None diff --git 
a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index ed7ffb21e85aa..287d66b06d6e9 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, + FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod, @@ -151,7 +151,7 @@ class AWQMarlinConfig(QuantizationConfig): "Falling back to Moe WNA16 kernels.") return MoeWNA16Config.from_config( self.full_config).get_quant_method(layer, prefix) - return AWQMoEMethod(self) + return AWQMoEMethod(self, layer.moe_config) return None @classmethod @@ -328,7 +328,12 @@ class AWQMarlinLinearMethod(LinearMethodBase): class AWQMoEMethod(FusedMoEMethodBase): - def __init__(self, quant_config: AWQMarlinConfig): + def __init__( + self, + quant_config: AWQMarlinConfig, + moe: FusedMoEConfig, + ): + super().__init__(moe) self.quant_config = quant_config if self.quant_config.weight_bits != 4: raise ValueError("AWQMoEMethod only supports 4bit now.") @@ -500,6 +505,8 @@ class AWQMoEMethod(FusedMoEMethodBase): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `AWQMoEMethod` yet.") @@ -516,7 +523,8 @@ class AWQMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + 
e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return torch.ops.vllm.fused_marlin_moe( x, @@ -535,4 +543,4 @@ class AWQMoEMethod(FusedMoEMethodBase): expert_map=expert_map, w1_zeros=layer.w13_qzeros, w2_zeros=layer.w2_qzeros, - workspace=layer.workspace) \ No newline at end of file + workspace=layer.workspace) diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 0204ff46852f4..b7897a43793c7 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -7,6 +7,7 @@ import torch from packaging import version from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + FusedMoEConfig, FusedMoEMethodBase) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod, @@ -132,7 +133,7 @@ class BitsAndBytesConfig(QuantizationConfig): return UnquantizedLinearMethod() return BitsAndBytesLinearMethod(self) elif isinstance(layer, FusedMoE): - return BitsAndBytesMoEMethod(self) + return BitsAndBytesMoEMethod(self, layer.moe_config) return None @@ -411,7 +412,12 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): quant_config: The BitsAndBytes quantization config. 
""" - def __init__(self, quant_config: BitsAndBytesConfig): + def __init__( + self, + quant_config: BitsAndBytesConfig, + moe: FusedMoEConfig, + ): + super().__init__(moe) try: import bitsandbytes if version.parse( @@ -422,7 +428,6 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): raise ImportError("Please install bitsandbytes>=0.46.1 via " "`pip install bitsandbytes>=0.46.1` to use " "bitsandbytes quantizer.") from err - self.topk_indices_dtype = None self.quant_config = quant_config def create_weights( @@ -470,6 +475,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts + assert self.fused_experts is None if enable_eplb: raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 637a84372990a..245cf122ebab1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -26,10 +26,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24, CompressedTensorsScheme, CompressedTensorsW4A4Fp4, - CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4, - CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, - CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16) + CompressedTensorsW4A8Fp8, CompressedTensorsW4A8Int, + CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( 
find_matched_target, is_activation_quantization_format, should_ignore_layer) @@ -200,8 +200,10 @@ class CompressedTensorsConfig(QuantizationConfig): format ) if format is not None else is_activation_quantization_format( quant_format) - if act_quant_format: - input_activations = quant_config.get("input_activations") + # TODO(czhu): w4a8fp8 is in packed-quantized format + # but needs input activation quantization + input_activations = quant_config.get("input_activations") + if act_quant_format or input_activations: # The only case where we have activation quant supported # but no input_activations provided in the config # should be w8a16fp8 w8a16fp8 can also run for cases where @@ -352,6 +354,28 @@ class CompressedTensorsConfig(QuantizationConfig): input_quant.strategy == QuantizationStrategy.TENSOR) return is_symmetric_activation and is_per_tensor_activation + def _is_fp8_w4a8(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + if not weight_quant or not input_quant: + return False + is_weight_4_bits = weight_quant.num_bits == 4 + is_activation_8_bits = input_quant.num_bits == 8 + weight_strategy = ( + weight_quant.strategy == QuantizationStrategy.GROUP.value) + is_token = (weight_strategy and input_quant.strategy + == QuantizationStrategy.TOKEN.value) + is_dynamic = not weight_quant.dynamic and input_quant.dynamic + is_symmetric = weight_quant.symmetric and input_quant.symmetric + # Only per-group symmetric weight (4bit) + # + per-tok symmetric activation (8bit) quantization supported. 
+ return (is_weight_4_bits and is_activation_8_bits and is_token + and is_symmetric and is_dynamic) + + def _is_fp8_w4a8_sm90(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + return (self._check_scheme_supported(90, error=False, match_exact=True) + and self._is_fp8_w4a8(weight_quant, input_quant)) + def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: return (self._check_scheme_supported(90, error=False, match_exact=True) @@ -401,19 +425,30 @@ class CompressedTensorsConfig(QuantizationConfig): weight_quant: BaseModel, input_quant: BaseModel, format: Optional[str] = None) -> "CompressedTensorsScheme": + + # use the per-layer format if defined, otherwise, use global format + format = format if format is not None else self.quant_format + # Detect If Mixed Precision if self._is_fp4a16_nvfp4(weight_quant, input_quant): return CompressedTensorsW4A16Fp4() + if self._is_fp8_w4a8_sm90(weight_quant, input_quant): + return CompressedTensorsW4A8Fp8(num_bits=weight_quant.num_bits, + strategy=weight_quant.strategy, + symmetric=weight_quant.symmetric, + group_size=weight_quant.group_size, + actorder=weight_quant.actorder) + if self._is_wNa16_group_channel(weight_quant, input_quant): - if (self.quant_format == CompressionFormat.marlin_24.value + if (format == CompressionFormat.marlin_24.value and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): assert weight_quant.symmetric return CompressedTensorsW4A16Sparse24( strategy=weight_quant.strategy, num_bits=weight_quant.num_bits, group_size=weight_quant.group_size) - if (self.quant_format == CompressionFormat.pack_quantized.value + if (format == CompressionFormat.pack_quantized.value and weight_quant.num_bits in WNA16_SUPPORTED_BITS): return CompressedTensorsWNA16( num_bits=weight_quant.num_bits, @@ -422,10 +457,7 @@ class CompressedTensorsConfig(QuantizationConfig): group_size=weight_quant.group_size, actorder=weight_quant.actorder) - act_quant_format = 
is_activation_quantization_format( - format - ) if format is not None else is_activation_quantization_format( - self.quant_format) + act_quant_format = is_activation_quantization_format(format) if act_quant_format: if self._is_fp4a4_nvfp4(weight_quant, input_quant): if cutlass_fp4_supported( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 839942beaf406..6279bb8b60570 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -11,20 +11,23 @@ from compressed_tensors.quantization import (ActivationOrdering, QuantizationStrategy) import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa - FlashInferCutlassMoEPrepareAndFinalize) +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + is_valid_flashinfer_cutlass_fused_moe) from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + find_matched_target) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - build_flashinfer_fp4_cutlass_moe_kernel, - flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1) + 
build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, + select_nvfp4_gemm_impl) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_moe_marlin_supports_layer, marlin_make_workspace_new, marlin_moe_permute_scales) @@ -58,15 +61,46 @@ __all__ = [ class CompressedTensorsMoEMethod(FusedMoEMethodBase): + def __init_(self, moe: FusedMoEConfig): + super().__init__(moe) + @staticmethod def get_moe_method( quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 - layer: torch.nn.Module, + layer: torch.nn.Module ) -> "CompressedTensorsMoEMethod": # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. - weight_quant = quant_config.target_scheme_map["Linear"].get("weights") - input_quant = quant_config.target_scheme_map["Linear"].get( + # Check if a using "Linear" to select scheems + if "Linear" in quant_config.target_scheme_map: + matched_target = "Linear" + else: + # May have instead defined the linear layers in the fused model + + fused_layers = [ + "re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*" + ] + current_scheme = None + for fused_layer in fused_layers: + # Check if one of the fused layers are defined in quant_config + matched_target = find_matched_target( + layer_name=fused_layer, + module=layer, + targets=quant_config.target_scheme_map.keys(), + fused_mapping=quant_config.packed_modules_mapping) + + # Only valid if down_proj, gate_proj, and up_proj + # are mapped to the same quant scheme in the quant_config + if current_scheme is None: + current_scheme = quant_config.target_scheme_map.get( + matched_target) + else: + assert current_scheme == quant_config.target_scheme_map.get( + matched_target) + + weight_quant = quant_config.target_scheme_map[matched_target].get( + "weights") + input_quant = quant_config.target_scheme_map[matched_target].get( "input_activations") if quant_config._is_wNa16_group_channel(weight_quant, input_quant): @@ 
-81,18 +115,22 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): "WNA16MoE is not supported with actorder=group/dynamic." ) logger.info_once("Using CompressedTensorsWNA16MoEMethod") - return CompressedTensorsWNA16MoEMethod(quant_config) + return CompressedTensorsWNA16MoEMethod(quant_config, + layer.moe_config) else: logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") - return CompressedTensorsWNA16MarlinMoEMethod(quant_config) + return CompressedTensorsWNA16MarlinMoEMethod( + quant_config, layer.moe_config) elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): - return CompressedTensorsW4A4MoeMethod() + return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer) elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) or quant_config._is_fp8_w8a8(weight_quant, input_quant)): - return CompressedTensorsW8A8Fp8MoEMethod(quant_config) + return CompressedTensorsW8A8Fp8MoEMethod(quant_config, + layer.moe_config) elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): - return CompressedTensorsW8A8Int8MoEMethod(quant_config) + return CompressedTensorsW8A8Int8MoEMethod(quant_config, + layer.moe_config) else: raise RuntimeError( f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") @@ -100,15 +138,16 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): - def __init__(self): + def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module): from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 detect_nvfp4_moe_support) + super().__init__(moe) _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported self.allow_flashinfer = _nvfp4.allow_flashinfer self.use_marlin = _nvfp4.use_marlin self.group_size = 16 - self.fused_experts = None # type: ignore[assignment] + self.layer = layer def 
create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -237,13 +276,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): return # swizzle weight scales - layer.w13_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale( + layer.w13_weight_scale = torch.nn.Parameter(swizzle_blockscale( layer.w13_weight_scale), - requires_grad=False) + requires_grad=False) - layer.w2_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale( + layer.w2_weight_scale = torch.nn.Parameter(swizzle_blockscale( layer.w2_weight_scale), - requires_grad=False) + requires_grad=False) # w13 w13_input_global_scale = layer.w13_input_global_scale.max( @@ -265,19 +304,36 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): layer.w2_input_scale_quant = torch.nn.Parameter( (layer.w2_input_global_scale), requires_grad=False) - def maybe_swap_experts_impl(self, moe_parallel_config): + def maybe_make_prepare_finalize( + self, + moe: FusedMoEConfig, + ) -> Optional[mk.FusedMoEPrepareAndFinalize]: if not self.allow_flashinfer: - return - self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel( - moe_parallel_config) + return super().maybe_make_prepare_finalize(moe) - def select_gemm_impl(self, prepare_finalize, moe): + prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( + moe, + a1_gscale=self.layer.w13_input_scale_quant, + ) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + ) -> mk.FusedMoEPermuteExpertsUnpermute: """Return the appropriate GEMM experts implementation.""" - assert moe is not None and prepare_finalize is not None - from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501 - select_nvfp4_gemm_impl) - - return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger) + experts = 
select_nvfp4_gemm_impl( + moe, + g1_alphas=self.layer.g1_alphas, + g2_alphas=self.layer.g2_alphas, + a1_gscale=self.layer.w13_input_scale_quant, + a2_gscale=self.layer.w2_input_scale_quant, + allow_flashinfer=self.allow_flashinfer, + ) + logger.debug_once("Using %s", experts.__class__.__name__) + return experts def apply( self, @@ -301,6 +357,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError("EPLB not supported for " "`CompressedTensorsW4A4MoeMethod` yet.") @@ -317,6 +375,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, ) if self.use_marlin: @@ -340,15 +399,49 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): # FlashInfer fused experts path if self.fused_experts is not None: - return flashinfer_fp4_cutlass_moe_forward( - self.fused_experts, - layer, - x, - topk_weights, - topk_ids, + assert is_valid_flashinfer_cutlass_fused_moe( + x, layer.w13_weight, layer.w2_weight), ( + "Flashinfer CUTLASS Fused MoE not applicable!") + + return self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, # TODO(shuw): fix later, now output is high prec activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + elif self.allow_flashinfer: + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + flashinfer_cutlass_moe_fp4) + + assert 
is_valid_flashinfer_cutlass_fused_moe( + x, layer.w13_weight, layer.w2_weight), ( + "Flashinfer CUTLASS Fused MoE not applicable!") + + return flashinfer_cutlass_moe_fp4( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, # TODO(shuw): fix later, now output is high prec + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -364,8 +457,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): a=x, w1_fp4=layer.w13_weight, w2_fp4=layer.w2_weight, - w1_blockscale=layer.w13_blockscale_swizzled, - w2_blockscale=layer.w2_blockscale_swizzled, + w1_blockscale=layer.w13_weight_scale, + w2_blockscale=layer.w2_weight_scale, g1_alphas=layer.g1_alphas, g2_alphas=layer.g2_alphas, a1_gscale=layer.w13_input_scale_quant, @@ -376,7 +469,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): n=layer.w2_weight.shape[2] * 2, k=x.shape[1], e=layer.w13_weight.shape[0], - device=x.device, apply_router_weight_on_input=apply_router_weight_on_input).to( x.dtype) @@ -384,15 +476,16 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): def __init__( - self, - quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + self, + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + moe: FusedMoEConfig, ): + super().__init__(moe) self.quant_config = quant_config self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( "weights") self.input_quant = self.quant_config.target_scheme_map["Linear"].get( "input_activations") - self.topk_indices_dtype = None per_tensor = 
(self.weight_quant.strategy == QuantizationStrategy.TENSOR and self.input_quant.strategy @@ -429,7 +522,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): self.weight_quant, self.input_quant) self.use_cutlass = (quant_config._is_fp8_w8a8_sm90( self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100) - self.fused_experts = None # type: ignore[assignment] self.disable_expert_map = False def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -607,6 +699,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): from vllm.model_executor.layers.fused_moe import fused_experts self.fused_experts_func = fused_experts + if self.use_cutlass: + device = layer.w13_weight.device + # ab_strides1 and c_strides2 are the same + self.ab_strides1_c_strides2 = torch.full( + (layer.local_num_experts, ), + layer.hidden_size, + device=device, + dtype=torch.int64) + self.ab_strides2 = torch.full( + (layer.local_num_experts, ), + layer.intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.c_strides1 = torch.full( + (layer.local_num_experts, ), + 2 * layer.intermediate_size_per_partition, + device=device, + dtype=torch.int64) + def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, @@ -614,25 +725,39 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ) -> FusedMoEPermuteExpertsUnpermute: # cutlass path if self.use_cutlass: - from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8 + from vllm.model_executor.layers.fused_moe import ( + CutlassBatchedExpertsFp8, CutlassExpertsFp8) - use_batched_format = (prepare_finalize.activation_format == - FusedMoEActivationFormat.BatchedExperts) + experts: FusedMoEPermuteExpertsUnpermute num_dispatchers = prepare_finalize.num_dispatchers() - num_experts = (moe.num_local_experts - if use_batched_format else moe.num_experts) - logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) - - experts = CutlassExpertsFp8( - 
num_experts, - moe.in_dtype, - self.input_quant.strategy == QuantizationStrategy.TOKEN, - self.weight_quant.strategy == QuantizationStrategy.CHANNEL, - num_dispatchers=num_dispatchers, - use_batched_format=use_batched_format, - ) + if (prepare_finalize.activation_format == + FusedMoEActivationFormat.BatchedExperts): + logger.debug("CutlassBatchedExpertsFp8(%s)", + self.__class__.__name__) + experts = CutlassBatchedExpertsFp8( + moe.num_local_experts, + num_dispatchers, + moe.in_dtype, + self.input_quant.strategy == QuantizationStrategy.TOKEN, + self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + ab_strides1=self.ab_strides1_c_strides2, + ab_strides2=self.ab_strides2, + c_strides1=self.c_strides1, + c_strides2=self.ab_strides1_c_strides2, + ) + else: + logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) + experts = CutlassExpertsFp8( + moe.in_dtype, + self.input_quant.strategy == QuantizationStrategy.TOKEN, + self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + ab_strides1=self.ab_strides1_c_strides2, + ab_strides2=self.ab_strides2, + c_strides1=self.c_strides1, + c_strides2=self.ab_strides1_c_strides2, + ) self.disable_expert_map = (num_dispatchers > 1 or not experts.supports_expert_map()) @@ -754,6 +879,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): expert_map=None if self.disable_expert_map else expert_map, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, + ab_strides1=self.ab_strides1_c_strides2, + ab_strides2=self.ab_strides2, + c_strides1=self.c_strides1, + c_strides2=self.ab_strides1_c_strides2, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, ) @@ -834,9 +963,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): def __init__( - self, - quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + self, + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + 
moe: FusedMoEConfig, ): + super().__init__(moe) self.quant_config = quant_config self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( "weights") @@ -934,6 +1065,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for " @@ -951,7 +1084,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return fused_experts( hidden_states=x, @@ -975,9 +1109,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): def __init__( - self, - quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + self, + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + moe: FusedMoEConfig, ): + super().__init__(moe) self.quant_config = quant_config # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. 
@@ -1233,6 +1369,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for " @@ -1251,7 +1389,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return torch.ops.vllm.fused_marlin_moe( x, @@ -1279,9 +1418,11 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): def __init__( - self, - quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + self, + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + moe: FusedMoEConfig, ): + super().__init__(moe) self.quant_config = quant_config # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. 
@@ -1459,6 +1600,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError("EPLB not supported for " "`CompressedTensorsWNA16MoEMethod` yet.") @@ -1475,7 +1618,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return fused_experts( x, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 734fa603ba7b9..cac65cca5093f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -3,6 +3,7 @@ from .compressed_tensors_scheme import CompressedTensorsScheme from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 +from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8 from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24) @@ -21,5 +22,6 @@ __all__ = [ "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8", "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS", "CompressedTensors24", "CompressedTensorsW4A16Fp4", - "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int" + "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int", + "CompressedTensorsW4A8Fp8" ] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 63bfe565b1211..dedd681f15ded 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 run_nvfp4_emulations) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + swizzle_blockscale) from vllm.model_executor.parameter import (GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) @@ -83,29 +85,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): weight_loader=weight_loader) layer.register_parameter("input_global_scale", input_global_scale) - def swizzle_blockscale(self, scale: torch.tensor): - assert (scale.dtype == torch.float8_e4m3fn) - # Pad and blockwise interleave weight_scale - scale_ndim = scale.ndim - if scale.ndim == 2: - scale = scale.unsqueeze(0) - assert scale.ndim == 3 - B, M, K = scale.shape - round_up_multiple = lambda x, m: (x + m - 1) // m * m - M_padded = round_up_multiple(M, 128) - K_padded = round_up_multiple(K, 4) - padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) - padded_scale[:B, :M, :K] = scale - batches, rows, cols = padded_scale.shape - assert rows % 128 == 0 - assert cols % 4 == 0 - padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, - cols // 4, 4) - swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) - swizzled_scale = swizzled_scale.contiguous().cuda() - return (swizzled_scale.reshape(M, K) - if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) - def process_weights_after_loading(self, layer) -> None: global_input_scale = 
layer.input_global_scale.max().to(torch.float32) @@ -133,13 +112,12 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): torch.uint8), epilogue_tile_m).reshape( weight_scale.shape).view(torch.float8_e4m3fn)) - layer.weight_scale_swizzled = Parameter(weight_scale, - requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight_packed = Parameter(weight, requires_grad=False) else: - swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) - layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, - requires_grad=False) + swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) + layer.weight_scale = Parameter(swizzled_weight_scale, + requires_grad=False) layer.weight_packed = Parameter(layer.weight_packed.data, requires_grad=False) @@ -157,7 +135,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): x=x, input_global_scale=layer.input_global_scale, weight=layer.weight_packed, - weight_scale_swizzled=layer.weight_scale_swizzled, + weight_scale_swizzled=layer.weight_scale, weight_global_scale=layer.weight_global_scale) if bias is not None: out = out + bias @@ -170,7 +148,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) mm_args = (x_fp4, layer.weight_packed, x_blockscale, - layer.weight_scale_swizzled, layer.alpha, output_dtype) + layer.weight_scale, layer.alpha, output_dtype) if self.backend == "flashinfer-trtllm": out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm") elif self.backend == "flashinfer-cutlass": diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py new file mode 100644 index 0000000000000..3d9827058803e --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py @@ -0,0 +1,169 @@ +# 
SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, Optional + +import torch +from compressed_tensors.quantization import ActivationOrdering + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( + MPLinearLayerConfig, choose_mp_linear_kernel) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_repeat_scales_on_all_ranks) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) +# yapf: enable +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + +__all__ = ["CompressedTensorsW4A8Fp8"] +W4A8_SUPPORTED_TYPES_MAP = { + 4: scalar_types.int4, +} +W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys()) + + +class CompressedTensorsW4A8Fp8(CompressedTensorsScheme): + _kernel_backends_being_used: set[str] = set() + + def __init__(self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None, + symmetric: Optional[bool] = True, + actorder: Optional[ActivationOrdering] = None): + + self.pack_factor = 32 // num_bits + self.strategy = strategy + self.symmetric = symmetric + self.group_size = -1 if group_size is None else group_size + self.has_g_idx = actorder == ActivationOrdering.GROUP + + if self.group_size != 128 or self.strategy != "group": + raise ValueError("W4A8 kernels require group quantization " \ + "with group size 128") + + if num_bits not in W4A8_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. 
" + f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}") + + self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits] + + @classmethod + def get_min_capability(cls) -> int: + # hopper + return 90 + + def create_weights(self, layer: torch.nn.Module, output_size: int, + input_size: int, output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + output_size_per_partition = sum(output_partition_sizes) + + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_type, + act_type=torch.float8_e4m3fn, # always use fp8(e4m3) + group_size=self.group_size, + zero_points=not self.symmetric, + has_g_idx=self.has_g_idx, + out_type=params_dtype + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsW4A8Fp8", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # If group_size is -1, we are in channelwise case. + group_size = self.group_size if self.group_size != -1 else input_size + row_parallel = (input_size != input_size_per_partition) + partition_scales = not marlin_repeat_scales_on_all_ranks( + self.has_g_idx, self.group_size, row_parallel) + + scales_and_zp_size = input_size // group_size + + if partition_scales: + assert input_size_per_partition % group_size == 0 + scales_and_zp_size = input_size_per_partition // group_size + + weight = PackedvLLMParameter(input_dim=1, + output_dim=0, + weight_loader=weight_loader, + packed_factor=self.pack_factor, + packed_dim=1, + data=torch.empty( + output_size_per_partition, + input_size_per_partition // + self.pack_factor, + dtype=torch.int32, + )) + + # TODO(czhu): allocate the packed fp8 scales memory here? 
+ # the scales will be expanded by 8x via `cutlass_pack_scale_fp8` + weight_scale_args = { + "weight_loader": + weight_loader, + "data": + torch.empty( + output_size_per_partition, + scales_and_zp_size, + dtype=torch.float8_e4m3fn, + ) + } + + if not partition_scales: + weight_scale = ChannelQuantScaleParameter(output_dim=0, + **weight_scale_args) + else: + weight_scale = GroupQuantScaleParameter(output_dim=0, + input_dim=1, + **weight_scale_args) + + # A 2D array defining the original shape of the weights + # before packing + weight_shape = BasevLLMParameter(data=torch.empty(2, + dtype=torch.int64), + weight_loader=weight_loader) + + # per-channel scales + weight_chan_scale = ChannelQuantScaleParameter( + data=torch.empty((output_size_per_partition, 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("weight_packed", weight) + layer.register_parameter("weight_scale", weight_scale) + layer.register_parameter("weight_shape", weight_shape) + layer.register_parameter("weight_chan_scale", weight_chan_scale) + + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name="weight_zero_point", + w_gidx_param_name="weight_g_idx") + + # Checkpoints are serialized in compressed-tensors format, which is + # different from the format the kernel may want. Handle repacking here. 
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 47eca80609e0e..3e43caa4cbf72 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -6,7 +6,8 @@ from typing import Any, Callable, Optional import torch from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group -from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, + FusedMoEMethodBase) from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -46,13 +47,18 @@ class ExpertsInt8Config(QuantizationConfig): if isinstance(layer, LinearBase): return UnquantizedLinearMethod() elif isinstance(layer, FusedMoE): - return ExpertsInt8MoEMethod(self) + return ExpertsInt8MoEMethod(self, layer.moe_config) return None class ExpertsInt8MoEMethod(FusedMoEMethodBase): - def __init__(self, quant_config: ExpertsInt8Config): + def __init__( + self, + quant_config: ExpertsInt8Config, + moe: FusedMoEConfig, + ): + super().__init__(moe) self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -122,6 +128,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `ExpertsInt8MoEMethod` yet.") @@ -138,7 +146,8 
@@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return fused_experts( x, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index dbd5234286952..be358cfa949f0 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import functools from typing import TYPE_CHECKING, Any, Callable, Optional import torch @@ -10,6 +9,7 @@ from torch.nn import Module from torch.nn.parameter import Parameter import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -24,8 +24,11 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors, - rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31) + FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, + build_flashinfer_fp8_cutlass_moe_prepare_finalize, + flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, + register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, + select_cutlass_fp8_gemm_impl, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) from 
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( @@ -45,8 +48,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used, - is_deep_gemm_supported) +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported from vllm.utils.flashinfer import has_flashinfer_moe if TYPE_CHECKING: @@ -126,6 +128,10 @@ class Fp8Config(QuantizationConfig): ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) + if not ignored_layers: + ignored_layers = cls.get_from_keys_or(config, + ["modules_to_not_convert"], + None) return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, activation_scheme=activation_scheme, ignored_layers=ignored_layers, @@ -142,7 +148,7 @@ class Fp8Config(QuantizationConfig): return UnquantizedLinearMethod() return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): - return Fp8MoEMethod(self) + return Fp8MoEMethod(self, layer) elif isinstance(layer, Attention): return Fp8KVCacheMethod(self) return None @@ -216,8 +222,7 @@ class Fp8LinearMethod(LinearMethodBase): self.fp8_linear = Fp8LinearOp( act_quant_static=self.act_q_static, - act_quant_group_shape=self.act_q_group_shape, - cutlass_fp8_supported=cutlass_fp8_supported()) + act_quant_group_shape=self.act_q_group_shape) def create_weights( self, @@ -369,6 +374,8 @@ class Fp8LinearMethod(LinearMethodBase): # Update the layer with the new values. layer.weight = Parameter(qweight.t(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) + # layer.input_scale is None indicates dynamic quant and scale is + # computed from input. 
layer.input_scale = None # If checkpoint is fp8, handle that there are N scales for N @@ -419,7 +426,7 @@ class Fp8LinearMethod(LinearMethodBase): # On B200, if E8M0 for DeepGemm is used, we need to # requantize the weight and input to the specific scale # at the same time. - if is_blackwell_deep_gemm_e8m0_used(): + if is_deep_gemm_e8m0_used(): assert layer.weight_block_size is not None block_sz = tuple(layer.weight_block_size) requant_weight_ue8m0_inplace( @@ -479,17 +486,20 @@ class Fp8MoEMethod(FusedMoEMethodBase): quant_config: The quantization config. """ - def __init__(self, quant_config: Fp8Config): - - from vllm.model_executor.layers.fused_moe import fused_experts + def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): + super().__init__(layer.moe_config) + self.layer = layer self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None - self.flashinfer_moe_enabled = False + self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None + self.fused_experts: Optional[ + mk.FusedMoEModularKernel] = None # type: ignore if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): + self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( - "Using FlashInfer MoE FP8 kernels for Fp8MoEMethod.") - self.flashinfer_moe_enabled = True + f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" + ) # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization self.use_marlin = (not current_platform.has_device_capability(89) @@ -529,14 +539,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): "CutlassBlockScaledGroupedGemm not supported on the current " "platform.") - self.topk_indices_dtype = None - self.fused_experts = functools.partial( # type: ignore - fused_experts, - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - allow_deep_gemm=self.allow_deep_gemm, - allow_cutlass_block_scaled_grouped_gemm=( - 
self.allow_cutlass_block_scaled_grouped_gemm)) + def maybe_make_prepare_finalize( + self, + moe: FusedMoEConfig, + ) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS: + return super().maybe_make_prepare_finalize(moe) + + prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( + moe, + layer=self.layer, + ) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -685,7 +700,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): normalize_e4m3fn_to_e4m3fnuz( layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale) - elif self.flashinfer_moe_enabled: + elif self.flashinfer_moe_backend is not None: # NOTE: weights have to be swapped since the activation is # applied on different half for flashinfer vs vllm w13_weight = swap_w13_to_w31(layer.w13_weight.data) @@ -693,9 +708,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w13_weight_scale_inv.data) w2_weight = layer.w2_weight.data w2_weight_scale_inv = layer.w2_weight_scale_inv.data - if not self.block_quant: - register_moe_scaling_factors(layer) - rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) else: w13_weight = layer.w13_weight.data w13_weight_scale_inv = layer.w13_weight_scale_inv.data @@ -721,7 +733,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. - if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used(): + if self.allow_deep_gemm and not is_deep_gemm_e8m0_used(): # Lazy import to avoid CUDA initialization problems. 
if _is_col_major(layer.w13_weight_scale_inv): layer.w13_weight_scale_inv = \ @@ -841,13 +853,24 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) + if self.flashinfer_moe_backend is not None: + # NOTE: weights have to be swapped since the activation is + # applied on different half for flashinfer vs vllm + assert not self.block_quant + register_moe_scaling_factors(layer) + w13_weight = swap_w13_to_w31(layer.w13_weight.data) + if self.flashinfer_moe_backend == \ + FlashinferMoeBackend.TENSORRT_LLM: + rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) + layer.w13_weight.data = w13_weight.data + if self.use_marlin: prepare_moe_fp8_layer_for_marlin(layer, False) # Activations not quantized for marlin. del layer.w13_input_scale del layer.w2_input_scale - if is_blackwell_deep_gemm_e8m0_used(): + if is_deep_gemm_e8m0_used(): assert layer.weight_block_size is not None # Re-quantise the expert weights so their scales are UE8M0. 
block_sz = tuple(layer.weight_block_size) @@ -899,6 +922,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): per_act_token_quant=False, allow_deep_gemm=self.allow_deep_gemm, ) + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + experts = select_cutlass_fp8_gemm_impl( + moe, + self.layer, + ) + logger.debug_once("Using %s", experts.__class__.__name__) + return experts else: logger.debug( "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s", @@ -937,25 +967,66 @@ class Fp8MoEMethod(FusedMoEMethodBase): assert logical_to_physical_map is not None assert logical_replica_count is not None assert isinstance(layer, FusedMoE) - if not self.flashinfer_moe_enabled: - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype, - enable_eplb=enable_eplb, - expert_map=expert_map, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) + + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") + assert scoring_func == 'sigmoid', ( + f"Expected 'sigmoid' scoring func but got {scoring_func}") + if self.block_quant: + assert (renormalize and use_grouped_topk + and custom_routing_function is None) + + return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( + routing_logits=router_logits.to(torch.float32), + routing_bias=e_score_correction_bias, + x=x, + w13_weight=layer.w13_weight, + w13_weight_scale_inv=layer.w13_weight_scale_inv, + w2_weight=layer.w2_weight, + w2_weight_scale_inv=layer.w2_weight_scale_inv, + 
global_num_experts=global_num_experts, + top_k=top_k, + num_expert_group=num_expert_group, + topk_group=topk_group, + intermediate_size=layer.intermediate_size_per_partition, + expert_offset=layer.ep_rank * layer.local_num_experts, + local_num_experts=layer.local_num_experts, + block_shape=self.quant_config.weight_block_size, + routed_scaling=1.0, + ) + else: + assert (not renormalize + and custom_routing_function is not None) + return apply_flashinfer_per_tensor_scale_fp8( + layer=layer, + hidden_states=x, + router_logits=router_logits, + routing_bias=e_score_correction_bias, + global_num_experts=global_num_experts, + top_k=top_k, + num_expert_group=num_expert_group, + topk_group=topk_group, + apply_router_weight_on_input=apply_router_weight_on_input) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 @@ -995,46 +1066,40 @@ class Fp8MoEMethod(FusedMoEMethodBase): apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map) - elif self.flashinfer_moe_enabled: - assert activation == 'silu' - assert scoring_func == 'sigmoid' - if self.block_quant: - assert (renormalize and use_grouped_topk - and custom_routing_function is None) - - return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( - routing_logits=router_logits.to(torch.float32), - 
routing_bias=e_score_correction_bias, - x=x, - w13_weight=layer.w13_weight, - w13_weight_scale_inv=layer.w13_weight_scale_inv, - w2_weight=layer.w2_weight, - w2_weight_scale_inv=layer.w2_weight_scale_inv, + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + assert not self.block_quant + assert (not renormalize and custom_routing_function is not None) + assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") + assert scoring_func == 'sigmoid', ( + f"Expected 'sigmoid' scoring func but got {scoring_func}") + if self.fused_experts is not None: + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=False, + activation=activation, global_num_experts=global_num_experts, - top_k=top_k, - num_expert_group=num_expert_group, - topk_group=topk_group, - intermediate_size=layer.intermediate_size_per_partition, - expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - block_shape=self.quant_config.weight_block_size, - routed_scaling=1.0, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, ) else: - assert (not renormalize - and custom_routing_function is not None) - return apply_flashinfer_per_tensor_scale_fp8( - layer=layer, - hidden_states=x, - router_logits=router_logits, - routing_bias=e_score_correction_bias, + return flashinfer_cutlass_moe_fp8( + x, + layer, + topk_weights, + topk_ids, + inplace=False, + activation=activation, global_num_experts=global_num_experts, - top_k=top_k, - num_expert_group=num_expert_group, - topk_group=topk_group, - apply_router_weight_on_input=apply_router_weight_on_input) + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) else: - return self.fused_experts( + common_kwargs = dict( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -1053,6 +1118,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): 
a2_scale=layer.w2_input_scale, ) + if self.fused_experts is not None: + return self.fused_experts(**common_kwargs) + else: + from vllm.model_executor.layers.fused_moe import fused_experts + return fused_experts( + **common_kwargs, + use_fp8_w8a8=True, + block_shape=self.quant_config.weight_block_size, + allow_deep_gemm=self.allow_deep_gemm, + allow_cutlass_block_scaled_grouped_gemm=( + self.allow_cutlass_block_scaled_grouped_gemm), + ) + class Fp8KVCacheMethod(BaseKVCacheMethod): """ diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 86da04c39989b..90222f2e3b0e5 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -11,8 +11,10 @@ from torch.nn.parameter import Parameter, UninitializedParameter from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + FusedMoEConfig, FusedMoEMethodBase) -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) @@ -27,8 +29,10 @@ logger = init_logger(__name__) class GGUFConfig(QuantizationConfig): """Config class for GGUF.""" - def __init__(self, ) -> None: + def __init__(self, + unquantized_modules: Optional[list[str]] = None) -> None: super().__init__() + self.unquantized_modules = unquantized_modules or [] def __repr__(self) -> str: return ("GGUFConfig()") @@ -54,14 +58,20 @@ class GGUFConfig(QuantizationConfig): def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): + if is_layer_skipped_gguf(prefix, self.unquantized_modules): + return 
UnquantizedLinearMethod() return GGUFLinearMethod(self) elif isinstance(layer, VocabParallelEmbedding): return GGUFEmbeddingMethod(self) elif isinstance(layer, FusedMoE): - return GGUFMoEMethod(self) + return GGUFMoEMethod(self, layer.moe_config) return None +def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]): + return any(module_name in prefix for module_name in unquantized_modules) + + UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16} STANDARD_QUANT_TYPES = { WeightType.Q4_0, @@ -445,7 +455,12 @@ class GGUFMoEMethod(FusedMoEMethodBase): quant_config: The GGUF quantization config. """ - def __init__(self, quant_config: GGUFConfig): + def __init__( + self, + quant_config: GGUFConfig, + moe: FusedMoEConfig, + ): + super().__init__(moe) self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -525,6 +540,8 @@ class GGUFMoEMethod(FusedMoEMethodBase): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ): + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `GGUFMoEMethod` yet.") @@ -545,7 +562,8 @@ class GGUFMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight, topk_weights, topk_ids, layer.w13_qweight_type.weight_type, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 3299221e3af37..c5d1e017014f3 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe # noqa from vllm import 
_custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, + FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod) from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) @@ -56,7 +56,7 @@ def get_moe_quant_method( # Dynamic per module/layer rules may override base config override_config(cloned_config, prefix=prefix) - return moe_method_cls(cloned_config) + return moe_method_cls(cloned_config, layer.moe_config) return None @@ -375,7 +375,12 @@ class GPTQMarlinLinearMethod(LinearMethodBase): class GPTQMarlinMoEMethod(FusedMoEMethodBase): """MoE Marlin method with quantization.""" - def __init__(self, quant_config: GPTQMarlinConfig) -> None: + def __init__( + self, + quant_config: GPTQMarlinConfig, + moe: FusedMoEConfig, + ) -> None: + super().__init__(moe) self.quant_config = quant_config if self.quant_config.quant_type.size_bits == 4: self.quant_type = scalar_types.uint4b8 @@ -646,6 +651,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `GPTQMarlinMoEMethod` yet.") @@ -662,7 +669,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return torch.ops.vllm.fused_marlin_moe( x, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py index 07ecc096231a4..1280f5f1eadf7 100644 --- 
a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py @@ -20,6 +20,7 @@ class MPLinearLayerConfig: group_size: int zero_points: bool has_g_idx: bool + out_type: Optional[torch.dtype] = None class MPLinearKernel(ABC): diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index a5084f6ee92cd..4bcfcd04b3d8b 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -10,6 +10,8 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas imp BitBLASLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501 ConchLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass import ( # noqa: E501 + CutlassW4A8LinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import ( # noqa: E501 Dynamic4bitLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 @@ -24,6 +26,7 @@ from vllm.platforms import current_platform # in priority/performance order (when available) _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [ + CutlassW4A8LinearKernel, MacheteLinearKernel, AllSparkLinearKernel, MarlinLinearKernel, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py new file mode 100644 index 0000000000000..9e23c0dd3595b --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + 
+from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class CutlassW4A8LinearKernel(MPLinearKernel): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # dynamic per-tok fp8 activation quantization + self.quant_fp8 = QuantFP8(static=False, + group_shape=GroupShape.PER_TOKEN) + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + if not current_platform.is_cuda(): + return False, "CUTLASS only supported on CUDA" + + if not current_platform.is_device_capability(90): + return False, "CUTLASS W4A8 requires compute capability of 90 "\ + "(Hopper)" + + if c.act_type != torch.float8_e4m3fn: + return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations" + + if c.has_g_idx: + return False, "Act reordering not supported by CUTLASS W4A8" + + if c.zero_points: + return False, "Zero points not supported by CUTLASS W4A8" + + if c.weight_type != scalar_types.int4: + return False, f"Quant type ({c.weight_type}) not supported by "\ + "CUTLASS W4A8, only supported int4" + + # TODO(czhu): support -1 (column-wise) + if c.group_size != 128: + return False, "Only group_size 128 is supported" + + in_features, out_features = c.partition_weight_shape + if in_features % 128 or out_features % 128: + return False, "K and N must be divisible by 128, got "\ + f"{c.partition_weight_shape}" + + if c.out_type != torch.bfloat16: + return False, "Only bfloat16 output type currently supported"\ + f"got {c.out_type=}" + + 
return True, None + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + + # TODO(czhu): optimize speed/mem usage + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.cutlass_encode_and_reorder_int4b( + x.data.t().contiguous().t()) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous().to(torch.float8_e4m3fn) + x.data = ops.cutlass_pack_scale_fp8(x.data) + return x + + # Encode/reorder weights and pack scales + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + self._transform_param(layer, "weight_chan_scale", lambda x: x) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + w_ch_s = layer.weight_chan_scale + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + x_2d, act_scales = self.quant_fp8(x_2d) + output = ops.cutlass_w4a8_mm(a=x_2d, + b_q=w_q, + b_group_scales=w_s, + b_group_size=c.group_size, + a_token_scales=act_scales, + b_channel_scales=w_ch_s) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 18f5ce04fd355..2bc68ab3ebd18 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -6,6 +6,8 @@ from typing import Optional from 
vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( AiterScaledMMLinearKernel) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( + CPUScaledMMLinearKernel) from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( CutlassScaledMMLinearKernel) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 @@ -18,7 +20,7 @@ from vllm.platforms import PlatformEnum, current_platform # in priority/performance order (when available) _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { - PlatformEnum.CPU: [CutlassScaledMMLinearKernel], + PlatformEnum.CPU: [CPUScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.TPU: [XLAScaledMMLinearKernel], diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py new file mode 100644 index 0000000000000..59d2b5bce962e --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py @@ -0,0 +1,206 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm import envs +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) +from vllm.model_executor.layers.utils import check_cpu_sgl_kernel +from vllm.platforms import current_platform +from vllm.platforms.interface import CpuArchEnum + +from .ScaledMMLinearKernel import (ScaledMMLinearKernel, + ScaledMMLinearLayerConfig) + + +class CPUScaledMMLinearKernel(ScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def can_implement( 
+ cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + if not current_platform.is_cpu(): + return False, "CPUScaledMM requires running on CPU." + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + weight = getattr(layer, self.w_q_name) + dtype = weight.dtype + N, K = weight.size() + if (current_platform.get_cpu_architecture() == CpuArchEnum.X86 + and envs.VLLM_CPU_SGL_KERNEL and self.config.input_symmetric + and check_cpu_sgl_kernel(N, K, dtype)): + self.linear_method = self._apply_weights_sgl + self.process_weights_for_sgl(layer) + else: + self.linear_method = self._apply_weights_onednn + self.process_weights_for_onednn(layer) + + def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: + # WEIGHT + # Transpose to [K, N] for convenience + weight = getattr(layer, self.w_q_name) + replace_parameter( + layer, self.w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False)) + + # WEIGHT SCALE + # oneDNN kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. 
+ is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, self.w_s_name) + if is_fused_module and not self.config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + # INPUT SCALE + if self.config.is_static_input_scheme: + input_scale = getattr(layer, self.i_s_name) + + if self.config.input_symmetric: + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False)) + setattr(layer, self.i_zp_name, None) + else: + input_zero_point = getattr(layer, self.i_zp_name) + + # reconstruct the ranges + int8_traits = torch.iinfo(torch.int8) + azps = input_zero_point.to(dtype=torch.int32) + range_max = (input_scale * (int8_traits.max - azps)).max() + range_min = (input_scale * (int8_traits.min - azps)).min() + + scale = (range_max - range_min) / (int8_traits.max - + int8_traits.min) + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(scale, requires_grad=False)) + + azp = (int8_traits.min - + range_min / scale).round().to(dtype=torch.int32) + replace_parameter(layer, self.i_zp_name, + torch.nn.Parameter(azp, requires_grad=False)) + + else: + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + + # Different from cutlass, oneDNN kernels only need the AZP adjustment + # term for dynamic quantization. And s_b should be folded into the + # term. 
Such as: + # s_a * s_b * [(A - zp_a)B] + bias = + # s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias = + # s_a * GEMM_output - s_a * zp_a * adj + bias + if not (self.config.input_symmetric + and self.config.is_static_input_scheme): + weight = getattr(layer, self.w_q_name) + weight_scale = getattr(layer, self.w_s_name) + azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32) + azp_adj = azp_adj * weight_scale.squeeze() + setattr(layer, self.azp_adj_name, + torch.nn.Parameter(azp_adj, requires_grad=False)) + else: + setattr(layer, self.azp_adj_name, None) + + weight = getattr(layer, self.w_q_name) + self.dnnl_handler = ops.create_onednn_scaled_mm( + weight, + getattr(layer, self.w_s_name), + torch.get_default_dtype(), + getattr(layer, self.i_s_name) is None, + not self.config.input_symmetric, + 32, + ) + # weight is prepacked and maintained by the dnnl_handler, + # release the original weight + setattr(layer, self.w_q_name, None) + del weight + + def process_weights_for_sgl(self, layer: torch.nn.Module) -> None: + # WEIGHT + weight = getattr(layer, self.w_q_name) + packed_weight = torch.ops._C.convert_weight_packed(weight) + replace_parameter( + layer, self.w_q_name, + torch.nn.Parameter(packed_weight, requires_grad=False)) + + if layer.bias is not None: + bias = layer.bias + layer.register_parameter( + "bias_fp32", + torch.nn.Parameter(bias.float().data, requires_grad=False)) + + # WEIGHT SCALE + # CPU SGL kernels only support per-channel. + # For per-tensor quant, convert to the per-channel case. 
+ weight_scale = getattr(layer, self.w_s_name) + if not self.config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + setattr(layer, self.azp_adj_name, None) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return self.linear_method( + layer, + x, + bias, + ) + + def _apply_weights_onednn( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + + # ops.scaled_int8_quant supports both dynamic and static quant: + # * dynamic, i_s is None and x_s computed from x. + # * static, i_s is scalar and x_s is i_s. + x_q, x_s, x_zp = ops.onednn_scaled_int8_quant( + x, i_s, i_zp, self.config.input_symmetric) + + m = x.size(0) + n = self.dnnl_handler.n + out = torch.empty((m, n), dtype=x.dtype) + ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, + bias) + + return out + + def _apply_weights_sgl( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, _, _, _ = self._get_weight_params(layer) + return torch.ops._C.int8_scaled_mm_with_quant( + x, + w_q, + w_s, + layer.bias_fp32 if bias is not None else None, + x.dtype, + True, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 6ddd4a9ec4233..2f982f96b0d04 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -25,8 +25,8 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): def can_implement( cls, c: 
ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: - if (not current_platform.is_cuda() and not current_platform.is_cpu()): - return False, "CutlassScaledMM requires running on CUDA or CPU." + if not current_platform.is_cuda(): + return False, "CutlassScaledMM requires running on CUDA." return True, None diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py deleted file mode 100644 index 18d1c13373df9..0000000000000 --- a/vllm/model_executor/layers/quantization/marlin.py +++ /dev/null @@ -1,263 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Optional - -import torch -from torch.nn.parameter import Parameter - -from vllm import _custom_ops as ops -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedvLLMParameter) - -logger = init_logger(__name__) - - -class MarlinConfig(QuantizationConfig): - """Config class for Marlin. - - Reference: https://github.com/IST-DASLab/marlin/tree/master - """ - - def __init__( - self, - group_size: int, - lm_head_quantized: bool, - ) -> None: - super().__init__() - - # Group size for the quantization. - self.group_size = group_size - self.lm_head_quantized = lm_head_quantized - if self.group_size != 128 and self.group_size != -1: - raise ValueError( - "Currently, only group size 128 and -1 (channelwise) " - "is supported for Marlin, but got group_size of " - f"{self.group_size}") - - # 4 Bits packed into 32 bit datatype. 
- self.pack_factor = 32 // 4 - - # Tile size used by marlin kernels. - self.tile_size = 16 - - # Min out_features dim - self.min_n_threads = 64 - - # Min in_features dim - self.min_k_threads = 128 - - # Max parallel problems to solve at once (improves large - # batch performance) - self.max_parallel = 16 - - # Permutation length used by the marlin kernels. - self.perm_len = 1024 - - def __repr__(self) -> str: - return (f"MarlinConfig(group_size={self.group_size}, " - f"lm_head_quantized={self.lm_head_quantized})") - - @classmethod - def get_name(cls) -> QuantizationMethods: - return "marlin" - - @classmethod - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.half] - - @classmethod - # Need to figure it out - def get_min_capability(cls) -> int: - return 80 - - @classmethod - def get_config_filenames(cls) -> list[str]: - return ["quantize_config.json"] - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "MarlinConfig": - group_size = cls.get_from_keys(config, ["group_size"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) - return cls(group_size, lm_head_quantized) - - @classmethod - def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: - # compat: autogptq >=0.8.0 use checkpoint_format: str - # compat: autogptq <=0.7.1 is_marlin_format: bool - is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin" - or hf_quant_cfg.get("is_marlin_format", False)) - - is_valid_user_quant = (user_quant is None or user_quant == "gptq" - or user_quant == "marlin") - - if is_marlin_format and is_valid_user_quant: - msg = ("The model is serialized in {} format. Using {} kernel.". 
- format(cls.get_name(), cls.get_name())) - logger.info(msg) - return cls.get_name() - - return None - - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["MarlinLinearMethod"]: - if (isinstance(layer, LinearBase) or - (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): - return MarlinLinearMethod(self) - return None - - -class MarlinLinearMethod(LinearMethodBase): - """Linear method for Marlin. - - Args: - quant_config: The Marlin quantization config. - """ - - def __init__(self, quant_config: MarlinConfig): - self.quant_config = quant_config - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - del output_size # Unused. - weight_loader = extra_weight_attrs["weight_loader"] - - if params_dtype != torch.float16: - raise ValueError( - f"The params dtype must be float16, but got {params_dtype}") - - # Validate output_size_per_partition - output_size_per_partition = sum(output_partition_sizes) - if output_size_per_partition % self.quant_config.min_n_threads != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"min_n_threads = {self.quant_config.min_n_threads}.") - if output_size_per_partition % self.quant_config.pack_factor != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"pack_factor = {self.quant_config.pack_factor}.") - - # Validate input_size_per_partition - if input_size_per_partition % self.quant_config.min_k_threads != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"min_k_threads = {self.quant_config.min_k_threads}.") - if (self.quant_config.group_size != -1 and - input_size_per_partition % self.quant_config.group_size != 0): - raise 
ValueError(f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"group_size = {self.quant_config.group_size}.") - - # Check that we have at least 4 tiles horizontally in the shard - num_tiles_per_perm = self.quant_config.perm_len // ( - self.quant_config.tile_size**2) - if output_size_per_partition % num_tiles_per_perm != 0: - raise ValueError( - "Each permutation group must reside on the same gpu") - - # Quantized 4Bit weights packed into Int32. - qweight = PackedvLLMParameter( - data=torch.empty( - input_size_per_partition // self.quant_config.tile_size, - output_size_per_partition * self.quant_config.tile_size // - self.quant_config.pack_factor, - device="cuda", - dtype=torch.int32, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=self.quant_config.pack_factor, - marlin_tile_size=self.quant_config.tile_size, - weight_loader=weight_loader) - - # Determine if channelwise or not - input_groups = (1 if self.quant_config.group_size == -1 else - input_size_per_partition // - self.quant_config.group_size) - - weight_scale_args = { - "data": - torch.empty( - input_groups, - output_size_per_partition, - device="cuda", - dtype=params_dtype, - ), - "weight_loader": - weight_loader - } - if input_groups == 1: - scales = ChannelQuantScaleParameter(output_dim=1, - **weight_scale_args) - else: - scales = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **weight_scale_args) - - # Allocate workspace (Used for internal locking mechanism) - max_workspace_size = ( - output_size_per_partition // - self.quant_config.min_n_threads) * self.quant_config.max_parallel - - workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, - device="cuda", - dtype=torch.int), - weight_loader=weight_loader) - - layer.register_parameter("B", qweight) - layer.register_parameter("s", scales) - layer.register_parameter("workspace", workspace) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # 
required by torch.compile - layer.B = Parameter(layer.B.data, requires_grad=False) - layer.s = Parameter(layer.s.data, requires_grad=False) - layer.workspace = Parameter(layer.workspace.data, requires_grad=False) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qweight = layer.B - scales = layer.s - workspace = layer.workspace - - x_2d = x.view(-1, x.shape[-1]) - - size_m = x_2d.shape[0] - size_k = x_2d.shape[1] - size_n = scales.shape[1] - - output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m, - size_n, size_k) - - output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) - - if bias is not None: - output.add_(bias) # In-place add - - return output diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 22fbbab00e919..72864853f7e0c 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from enum import Enum from typing import Any, Callable, Optional, Union import torch @@ -12,7 +11,9 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + is_valid_flashinfer_cutlass_fused_moe) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, @@ -22,11 +23,14 @@ from vllm.model_executor.layers.quantization.base_config 
import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - build_flashinfer_fp4_cutlass_moe_kernel, - flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1) + build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, + select_nvfp4_gemm_impl) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors, - rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31) + FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, + build_flashinfer_fp8_cutlass_moe_prepare_finalize, + flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, + register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, + select_cutlass_fp8_gemm_impl, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( apply_fp4_marlin_linear, is_fp4_marlin_supported, prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) @@ -47,11 +51,6 @@ QUANT_ALGOS = ["FP8", "NVFP4"] KV_CACHE_QUANT_ALGOS = ["FP8"] -class FlashinferMoeBackend(Enum): - TENSORRT_LLM = "TensorRT-LLM" - CUTLASS = "CUTLASS" - - class ModelOptFp8Config(QuantizationConfig): """Config class for ModelOpt FP8.""" @@ -177,7 +176,7 @@ class ModelOptFp8Config(QuantizationConfig): elif isinstance(layer, Attention): return ModelOptFp8KVCacheMethod(self) elif isinstance(layer, FusedMoE): - return ModelOptFp8MoEMethod(self) + return ModelOptFp8MoEMethod(self, layer) return None @@ -273,16 +272,52 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): quant_config: The ModelOpt quantization config. 
""" - def __init__(self, quant_config: ModelOptFp8Config) -> None: + def __init__( + self, + quant_config: ModelOptFp8Config, + layer: torch.nn.Module, + ) -> None: + super().__init__(layer.moe_config) + self.layer = layer self.quant_config = quant_config from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_fp8_supported) self.cutlass_fp8_supported = cutlass_fp8_supported() - self.flashinfer_moe_enabled = False + self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None + self.fused_experts: Optional[ + mk.FusedMoEModularKernel] = None # type: ignore if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): + self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( - "Using FlashInfer MoE FP8 kernels for ModelOptFp8MoEMethod.") - self.flashinfer_moe_enabled = True + f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" + ) + + def maybe_make_prepare_finalize( + self, + moe: FusedMoEConfig, + ) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if self.fused_experts is not None or \ + self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS: + return super().maybe_make_prepare_finalize(moe) + + prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( + moe, + layer=self.layer, + ) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + ) -> mk.FusedMoEPermuteExpertsUnpermute: + experts = select_cutlass_fp8_gemm_impl( + moe, + self.layer, + ) + logger.debug_once("Using %s", experts.__class__.__name__) + return experts def create_weights( self, @@ -426,11 +461,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): layer.w2_input_scale = Parameter(layer.w2_input_scale.max(), requires_grad=False) - if self.flashinfer_moe_enabled: + if self.flashinfer_moe_backend is not None: layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) - 
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, - layer.w2_weight) register_moe_scaling_factors(layer) + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + rotate_flashinfer_fp8_moe_weights(layer.w13_weight, + layer.w2_weight) def apply( self, @@ -458,8 +494,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): raise NotImplementedError( "EPLB not supported for `ModelOptFp8MoEMethod` yet.") - if self.flashinfer_moe_enabled: - assert activation == 'silu' + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") assert not renormalize return apply_flashinfer_per_tensor_scale_fp8( layer=layer, @@ -484,7 +521,38 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, ) + + if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + assert not renormalize + assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") + if self.fused_experts is not None: + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: + return flashinfer_cutlass_moe_fp8( + x, + layer, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts) return fused_experts( @@ -699,7 +767,7 @@ class ModelOptNvFp4Config(QuantizationConfig): elif isinstance(layer, Attention): return ModelOptFp8KVCacheMethod(self) elif isinstance(layer, FusedMoE): - return 
ModelOptNvFp4FusedMoE(self) + return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer) return None @@ -839,20 +907,18 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): torch.uint8), epilogue_tile_m).reshape( weight_scale.shape).view(torch.float8_e4m3fn)) - layer.weight_scale_swizzled = Parameter(weight_scale, - requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight = Parameter(weight, requires_grad=False) else: swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) - layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, - requires_grad=False) + layer.weight_scale = Parameter(swizzled_weight_scale, + requires_grad=False) layer.weight = Parameter(layer.weight.data, requires_grad=False) if self.backend == "marlin": prepare_fp4_layer_for_marlin(layer) del layer.alpha del layer.input_scale - del layer.weight_scale_swizzled def apply( self, @@ -883,14 +949,14 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): assert (x_fp4.dtype == torch.uint8) assert (layer.weight.dtype == torch.uint8) assert (x_blockscale.dtype == torch.float8_e4m3fn) - assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn) + assert (layer.weight_scale.dtype == torch.float8_e4m3fn) assert (layer.alpha.dtype == torch.float32) mm_args = ( x_fp4, layer.weight, x_blockscale, - layer.weight_scale_swizzled, + layer.weight_scale, layer.alpha, output_dtype, ) @@ -923,10 +989,17 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): quant_config: NVFP4 Quant Config """ - def __init__(self, quant_config: ModelOptNvFp4Config) -> None: - self.quant_config = quant_config + def __init__( + self, + quant_config: ModelOptNvFp4Config, + moe: FusedMoEConfig, + layer: torch.nn.Module, + ) -> None: from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 detect_nvfp4_moe_support) + super().__init__(moe) + self.quant_config = quant_config + self.layer = layer _nvfp4 = 
detect_nvfp4_moe_support(self.__class__.__name__) self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported self.allow_flashinfer = _nvfp4.allow_flashinfer @@ -934,45 +1007,42 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): self.flashinfer_moe_backend = None if self.allow_flashinfer: - flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND - if flashinfer_moe_backend == "throughput": - self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS - logger.info_once("Using FlashInfer CUTLASS kernels for " - "ModelOptNvFp4FusedMoE.") - elif flashinfer_moe_backend == "latency": - self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM - logger.info_once("Using FlashInfer TensorRT-LLM kernels for " - "ModelOptNvFp4FusedMoE.") - else: - allowed_backends = ["throughput", "latency"] - raise ValueError( - f"Unknown flashinfer moe backend: {flashinfer_moe_backend}" - f" expected one of {allowed_backends}") + self.flashinfer_moe_backend = get_flashinfer_moe_backend() + logger.info_once( + f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" + " for ModelOptNvFp4FusedMoE.") - self.fused_experts: Optional[ - mk.FusedMoEModularKernel] = None # type: ignore[assignment] - - def maybe_swap_experts_impl( + def maybe_make_prepare_finalize( self, - moe_parallel_config: FusedMoEParallelConfig, - ): - if not self.allow_flashinfer: - return - self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel( - moe_parallel_config) + moe: FusedMoEConfig, + ) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if (self.allow_flashinfer and self.flashinfer_moe_backend + == FlashinferMoeBackend.CUTLASS): + prepare_finalize = ( + build_flashinfer_fp4_cutlass_moe_prepare_finalize( + moe, + a1_gscale=self.layer.w13_input_scale_quant, + )) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize - # This method update self.fused_experts - # only prepare_finalize is not None call select_gemm_impl - # so when native cutlass fp4, fused_expert is in 
fuse_moe.py fused_expert - # when it's not called(TP case), we still have 2 kernels to use. - def select_gemm_impl(self, prepare_finalize, - moe) -> mk.FusedMoEPermuteExpertsUnpermute: + return super().maybe_make_prepare_finalize(moe) - assert moe is not None and prepare_finalize is not None - from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501 - select_nvfp4_gemm_impl) - - return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger) + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + ) -> mk.FusedMoEPermuteExpertsUnpermute: + experts = select_nvfp4_gemm_impl( + moe, + g1_alphas=self.layer.g1_alphas, + g2_alphas=self.layer.g2_alphas, + a1_gscale=self.layer.w13_input_scale_quant, + a2_gscale=self.layer.w2_input_scale_quant, + allow_flashinfer=self.allow_flashinfer, + ) + logger.debug_once("Using %s", experts.__class__.__name__) + return experts def uses_weight_scale_2_pattern(self) -> bool: """ @@ -1248,16 +1318,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): "Weight Blockscale must be represented as FP8-E4M3") w13_blockscale_swizzled = swizzle_blockscale( layer.w13_weight_scale) - layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled, - requires_grad=False) + layer.w13_weight_scale = Parameter(w13_blockscale_swizzled, + requires_grad=False) assert (layer.w2_weight_scale.shape[2] % 16 == 0), ( "Expected weight_scale.dim(1) to be divisible by 16") assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), ( "Weight Blockscale must be represented as FP8-E4M3") w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale) - layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled, - requires_grad=False) + layer.w2_weight_scale = Parameter(w2_blockscale_swizzled, + requires_grad=False) layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) @@ -1267,8 +1337,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): del 
layer.g2_alphas del layer.w13_input_scale_quant del layer.w2_input_scale_quant - del layer.w13_blockscale_swizzled - del layer.w2_blockscale_swizzled def apply( self, @@ -1362,7 +1430,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) if self.use_marlin: return torch.ops.vllm.fused_marlin_moe( @@ -1383,7 +1452,52 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): global_num_experts=global_num_experts, expert_map=expert_map) - if self.fused_experts is None: + if self.fused_experts is not None: + assert self.allow_flashinfer and \ + self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + + assert is_valid_flashinfer_cutlass_fused_moe( + x, layer.w13_weight, layer.w2_weight), ( + "Flashinfer CUTLASS Fused MoE not applicable!") + + out = self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, # TODO(shuw): fix later, now output is high prec + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + elif (self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS): + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + flashinfer_cutlass_moe_fp4) + + out = flashinfer_cutlass_moe_fp4( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + 
a2_gscale=layer.w2_input_scale_quant, + inplace=False, # TODO(shuw): fix later, now output is high prec + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: # If no modular kernel is provided, use cutlass_moe_fp4 for TP case # only (no EP). from vllm.model_executor.layers.fused_moe.cutlass_moe import ( @@ -1392,8 +1506,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): a=x, w1_fp4=layer.w13_weight, w2_fp4=layer.w2_weight, - w1_blockscale=layer.w13_blockscale_swizzled, - w2_blockscale=layer.w2_blockscale_swizzled, + w1_blockscale=layer.w13_weight_scale, + w2_blockscale=layer.w2_weight_scale, g1_alphas=layer.g1_alphas, g2_alphas=layer.g2_alphas, a1_gscale=layer.w13_input_scale_quant, @@ -1404,22 +1518,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): n=layer.w2_weight.shape[2] * 2, k=x.shape[1], e=layer.w13_weight.shape[0], - device=x.device, expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input) - else: - assert self.allow_flashinfer and \ - self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS - out = flashinfer_fp4_cutlass_moe_forward( - self.fused_experts, - layer, - x, - topk_weights, - topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) return out diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index c5055a02fa3d5..364d1ac314d2d 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -7,7 +7,7 @@ import torch from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) + FusedMoE, FusedMoEConfig, FusedMoEMethodBase, 
FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -160,7 +160,7 @@ class MoeWNA16Config(QuantizationConfig): else: raise ValueError("moe_wna16 only support gptq and awq.") elif isinstance(layer, FusedMoE): - return MoeWNA16Method(self) + return MoeWNA16Method(self, layer.moe_config) return None @@ -175,7 +175,12 @@ class MoeWNA16Method(FusedMoEMethodBase): quant_config: The MOE WNA16 (W8A16/W4A16) quantization config. """ - def __init__(self, quant_config: MoeWNA16Config): + def __init__( + self, + quant_config: MoeWNA16Config, + moe: FusedMoEConfig, + ): + super().__init__(moe) self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -302,6 +307,8 @@ class MoeWNA16Method(FusedMoEMethodBase): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `MoeWNA16Method` yet.") @@ -318,7 +325,8 @@ class MoeWNA16Method(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) weight_bits = self.quant_config.weight_bits has_zp = self.quant_config.has_zp diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index dbe6c603c0625..bdeb169a4b97f 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -6,10 +6,10 @@ import torch from torch.nn.parameter import Parameter from vllm import envs +from vllm.config import get_current_vllm_config +from vllm.logger import init_logger from 
vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, FusedMoEMethodBase) -from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( - triton_kernel_moe_forward) from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -26,12 +26,38 @@ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer, next_power_of_2, round_up) +from vllm.utils.flashinfer import has_flashinfer -if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): - # from flashinfer.fused_moe import cutlass_fused_moe - from flashinfer import (mxfp8_quantize, shuffle_matrix_a, - shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe) +logger = init_logger(__name__) + + +def _should_use_flashinfer_mxfp4_bf16(): + """Determine if FlashInfer MXFP4 BF16 should be used.""" + # If explicitly set, respect the setting + if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"): + return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 + + # Enable by default on SM100 if MXFP8 is not explicitly enabled + if (current_platform.is_device_capability(100) and has_flashinfer() + and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")): + logger.info_once( + "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. 
" + "For faster performance, consider setting " + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, " + "though this may impact accuracy.") + return True + + return False + + +def _should_use_flashinfer_mxfp4_mxfp8(): + """Determine if FlashInfer MXFP4 MXFP8 should be used.""" + return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + + +def should_use_flashinfer_mxfp4(): + return (_should_use_flashinfer_mxfp4_mxfp8() + or _should_use_flashinfer_mxfp4_bf16()) class Mxfp4Config(QuantizationConfig): @@ -82,17 +108,25 @@ class Mxfp4Config(QuantizationConfig): class Mxfp4MoEMethod(FusedMoEMethodBase): def __init__(self, moe: FusedMoEConfig): - super().__init__() + super().__init__(moe) self.topk_indices_dtype = None self.moe = moe self.use_marlin = self._should_use_marlin() + self.max_capture_size = get_current_vllm_config( + ).compilation_config.max_capture_size + + if current_platform.is_device_capability(100) and not has_flashinfer(): + logger.warning_once( + "MXFP4 MoE is enabled on Blackwell but FlashInfer " + "is not available. This may result in degraded performance. 
" + "Please `pip install vllm[flashinfer]` for best results.") def _should_use_marlin(self): if envs.VLLM_MXFP4_USE_MARLIN is not None: return envs.VLLM_MXFP4_USE_MARLIN if current_platform.is_cuda() and \ - not current_platform.has_device_capability(100): - if not current_platform.is_device_capability(90): + not current_platform.is_device_capability(100): + if not current_platform.has_device_capability(90): # marlin kernel has better performance on ampere return True if not has_triton_kernels(): @@ -138,8 +172,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): layer.hidden_size = hidden_size layer.intermediate_size_per_partition = \ intermediate_size_per_partition_after_pad - elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): + elif should_use_flashinfer_mxfp4(): # pad the intermediate size to be a multiple of 2 * mxfp4_block # for to hold non-uniform sharded tensor as well as swizzling # other padding to increase performance @@ -230,8 +263,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): def process_weights_after_loading(self, layer): if self.use_marlin: prepare_moe_fp4_layer_for_marlin(layer) - elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): + elif should_use_flashinfer_mxfp4(): + from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a layer.gemm1_alpha = Parameter(torch.tensor( [1.702] * self.num_experts, dtype=torch.float32).cuda(), requires_grad=False) @@ -478,17 +511,18 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): logical_replica_count), ( "MXFP4 are not supported with this configuration.") - if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): + if should_use_flashinfer_mxfp4(): + from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe assert not self.moe.use_ep, ( "EP is not supported for flashinfer mxfp4 moe backend yet.") - if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: + if _should_use_flashinfer_mxfp4_bf16(): assert 
x.dtype == torch.bfloat16 x_quant = x x_scale = None else: x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8 - x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1) + x_scale = x_scale.view(torch.float8_e4m3fn).reshape( + *x.shape[:-1], -1) trtllm_gen_output = trtllm_fp4_block_scale_moe( router_logits.to(torch.bfloat16), None, # routing_bias @@ -517,9 +551,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): self._get_tile_tokens_dim(x, top_k), 1 if renormalize else 0, # routing_method_type, renormalize True, # do finalize + tune_max_num_tokens=self.max_capture_size, )[0] return trtllm_gen_output else: + from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501 + triton_kernel_moe_forward) return triton_kernel_moe_forward( hidden_states=x, w1=self.w13_weight_triton_tensor, diff --git a/vllm/model_executor/layers/quantization/petit.py b/vllm/model_executor/layers/quantization/petit.py new file mode 100644 index 0000000000000..5b9fee69bb021 --- /dev/null +++ b/vllm/model_executor/layers/quantization/petit.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py + +from typing import Any, Optional + +import regex as re +import torch +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.petit_utils import ( + apply_petit_nvfp4_linear, prepare_nvfp4_layer_for_petit, + verify_petit_nvfp4_supported) +from 
vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.parameter import (ModelWeightParameter, + PerTensorScaleParameter) +from vllm.platforms import current_platform + +# Initialize logger for the module +logger = init_logger(__name__) + + +# Configuration class to support the NVFP4 quantized model +# generated by the ModelOpt quantization tool +class PetitNvFp4Config(QuantizationConfig): + """Config class for Petit FP4.""" + + def __init__( + self, + is_checkpoint_nvfp4_serialized: bool = False, + kv_cache_quant_algo: Optional[str] = None, + group_size: Optional[int] = None, + exclude_modules: Optional[list[str]] = None, + ) -> None: + self._check_hardware_support() + self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized + if is_checkpoint_nvfp4_serialized: + logger.warning("Detected nvfp4 checkpoint. Please note that the " + "format is experimental and subject to change.") + self.group_size = group_size + self.kv_cache_quant_algo = kv_cache_quant_algo + self.exclude_modules = exclude_modules + + def _check_hardware_support(self) -> None: + """ + Verifies that the current hardware is supported by the Petit backend. + This backend is specifically designed for AMD GPUs and is not + supported on the CUDA platform. + """ + # This check ensures the code is NOT running on an NVIDIA GPU. + if current_platform.is_cuda(): + raise ValueError( + "The 'petit' quantization backend is designed for AMD GPUs " + "and is not supported on the CUDA platform. 
For NVIDIA GPUs, " + "please use a different quantization method such as FP8, AWQ, " + "or GPTQ.") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "petit_nvfp4" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + # Petit supports the gfx90a and gfx942 GPUs + return 90 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["hf_quant_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config": + qc = cls.get_from_keys(config, ["quantization"]) + + quant_method_raw = qc.get("quant_algo") + if not isinstance(quant_method_raw, str) or not quant_method_raw: + raise ValueError( + "Missing or invalid 'quant_algo' in quantization config.") + quant_method = quant_method_raw.upper() + + group_size_raw = qc.get("group_size") + if not isinstance(group_size_raw, int): + raise ValueError( + "Missing or invalid 'group_size' (int) in hf_quant_config.json." 
+ ) + group_size = group_size_raw + + verify_petit_nvfp4_supported(quant_method, group_size) + + kv_cache_quant_algo_raw = qc.get("kv_cache_quant_algo") or "auto" + if not isinstance(kv_cache_quant_algo_raw, str): + raise ValueError( + "'kv_cache_quant_algo' must be a string if provided.") + kv_cache_quant_algo = kv_cache_quant_algo_raw + + exclude_raw = qc.get("exclude_modules", []) + if exclude_raw is None: + exclude_modules: list[str] = [] + elif isinstance(exclude_raw, list) and all( + isinstance(x, str) for x in exclude_raw): + exclude_modules = exclude_raw + else: + raise ValueError( + "'exclude_modules' must be a list[str] (or omitted).") + + is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method + + return cls( + is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized, + kv_cache_quant_algo=kv_cache_quant_algo, + group_size=group_size, + exclude_modules=exclude_modules, + ) + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + if not current_platform.is_rocm(): + return None + + qc = hf_quant_cfg.get("quantization", hf_quant_cfg) + algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper() + if algo in ("NVFP4", "MODELOPT_FP4", "MODELOPT"): + return cls.get_name() # "petit_nvfp4" + return None + + @classmethod + def is_petit_nvfp4_compatible(cls, quant_config: dict[str, Any]) -> bool: + qc = quant_config.get("quantization", quant_config) + algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper() + return algo == "NVFP4" + + def is_layer_excluded(self, prefix: str, + exclude_modules: list[str]) -> bool: + for pattern in exclude_modules: + regex_str = pattern.replace(".", r"\.").replace("*", r".*") + if re.fullmatch(regex_str, prefix): + return True + return False + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + exclude = 
self.require_exclude_modules() + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, exclude) or self.is_layer_excluded( + prefix, exclude): + return UnquantizedLinearMethod() + return PetitNvFp4LinearMethod(self) + elif isinstance(layer, Attention): + return PetitFp8KVCacheMethod(self) + return None + + def get_scaled_act_names(self) -> list[str]: + return [] + + def require_group_size(self) -> int: + if self.group_size is None: + logger.warning("group_size not set; defaulting to 16 for NVFP4.") + return 16 + return self.group_size + + def require_kv_cache_quant_algo(self) -> str: + return self.kv_cache_quant_algo or "auto" + + def require_exclude_modules(self) -> list[str]: + return list(self.exclude_modules or []) + + +class PetitFp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: PetitNvFp4Config): + super().__init__(quant_config) + + +class PetitNvFp4LinearMethod(LinearMethodBase): + """Linear method for NVFP4. + Supports loading NVFP4 checkpoints with the following structure: + + |Tensor Name | datatype | shape | + |----------------------------------------------------| + |input_scale | torch.float32 | scalar | + |weight | NVFP4(SE2M1) | [1, X, y/2] | + |weight_scale | FP8-E4M3 | [X, Y] | + |weight_scale_2 | torch.float32 | scalar | + + The weights are quantized per block of 16 elements. + Args: quant_config: The ModelOpt quantization config. 
+ """ + + def __init__(self, quant_config: PetitNvFp4Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + if not self.quant_config.is_checkpoint_nvfp4_serialized: + raise ValueError("NVFP4 quantization was selected, " + " dynamic quantization is not supported.") + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + if input_size_per_partition % 16 != 0: + raise ValueError("Unsupported model when in features size is " + "not multiple of 16") + + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_nvfp4_serialized + else params_dtype) + + weight = ModelWeightParameter( + data=torch.empty( + # 2 fp4 data is packed in one uint8 in the input dimension + output_size_per_partition, + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + input_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + + layer.register_parameter("input_scale", input_scale) + + weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale_2", weight_scale_2) + + group_size = self.quant_config.require_group_size() + weight_scale = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // group_size, + dtype=weight_dtype, + ), + 
input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + input_scale_2 = layer.input_scale.max().to(torch.float32) + weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) + layer.input_scale = Parameter(input_scale_2, requires_grad=False) + layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False) + layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2, + requires_grad=False) + + prepare_nvfp4_layer_for_petit(layer) + del layer.input_scale + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_petit_nvfp4_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_scale_2=layer.weight_scale_2, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index d11cba2caba88..466fd5fba7685 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -97,8 +97,8 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): self.quant_config.is_checkpoint_fp8_serialized = False self.fp8_linear = Fp8LinearOp( act_quant_static=False, - cutlass_fp8_supported=False, - act_quant_group_shape=GroupShape.PER_TOKEN) + act_quant_group_shape=GroupShape.PER_TOKEN, + force_fp8_e4m3fnuz=True) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = torch.nn.Parameter(layer.weight.data, diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py deleted file mode 100644 index 25978cb13b3ab..0000000000000 --- a/vllm/model_executor/layers/quantization/qqq.py +++ /dev/null @@ -1,275 +0,0 @@ -# SPDX-License-Identifier: 
Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Optional - -import torch -from torch.nn.parameter import Parameter - -from vllm import _custom_ops as ops -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedvLLMParameter) - -logger = init_logger(__name__) - -MARLIN_QQQ_TILE = 16 -MARLIN_QQQ_MIN_THREAD_N = 64 -MARLIN_QQQ_MIN_THREAD_K = 128 -MARLIN_QQQ_MAX_PARALLEL = 16 - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] -MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] -MARLIN_QQQ_SUPPORTED_SYM = [True] - - -class QQQConfig(QuantizationConfig): - """Config class for QQQ - - Reference: https://arxiv.org/pdf/2406.09904 - """ - - def __init__( - self, - weight_bits: int, - group_size: int, - is_sym: bool = True, - ) -> None: - super().__init__() - self.weight_bits = weight_bits - self.group_size = group_size - self.is_sym = is_sym - - # Verify - if self.weight_bits not in MARLIN_QQQ_SUPPORTED_NUM_BITS: - raise ValueError( - f"QQQ does not support weight_bits = {self.weight_bits}. " - f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} " - "are supported.") - if self.group_size not in MARLIN_QQQ_SUPPORTED_GROUP_SIZES: - raise ValueError( - f"QQQ does not support group_size = {self.group_size}. " - f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} " - "are supported.") - if self.is_sym not in MARLIN_QQQ_SUPPORTED_SYM: - raise ValueError( - f"QQQ does not support is_sym = {self.is_sym}. " - f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported.") - - # 4 Bits packed into 32 bit datatype. - self.pack_factor = 32 // self.weight_bits - - # Tile size used by QQQ kernels. 
- self.tile_size = MARLIN_QQQ_TILE - - # Min out_features dim - self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N - - # Min in_features dim - self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K - - # Max parallel problems to solve at once (improves large - # batch performance) - self.max_parallel = MARLIN_QQQ_MAX_PARALLEL - - # Permutation length used by the QQQ kernels. - self.perm_len = 1024 - - def __repr__(self) -> str: - return "QQQConfig(weight_bits={}, group_size={})".format( - self.weight_bits, self.group_size) - - @classmethod - def get_name(cls) -> QuantizationMethods: - return "qqq" - - @classmethod - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.half] - - @classmethod - def get_min_capability(cls) -> int: - return 80 - - @classmethod - def get_config_filenames(cls) -> list[str]: - """List of filenames to search for in the model directory.""" - return [ - "quant_config.json", - "quantize_config.json", - ] - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "QQQConfig": - weight_bits = cls.get_from_keys(config, ["wbits"]) - group_size = cls.get_from_keys(config, ["group_size"]) - return cls(weight_bits, group_size) - - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QQQLinearMethod"]: - if isinstance(layer, LinearBase): - return QQQLinearMethod(self) - return None - - -class QQQLinearMethod(LinearMethodBase): - """Linear method for QQQ. - - Args: - quant_config: The QQQ quantization config. 
- """ - - def __init__(self, quant_config: QQQConfig): - self.quant_config = quant_config - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - weight_loader = extra_weight_attrs["weight_loader"] - if params_dtype != torch.float16: - raise ValueError( - f"The params dtype must be float16, but got {params_dtype}") - - # Validate output_size_per_partition - output_size_per_partition = sum(output_partition_sizes) - if output_size_per_partition % self.quant_config.min_n_threads != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"min_n_threads = {self.quant_config.min_n_threads}.") - if output_size_per_partition % self.quant_config.pack_factor != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"pack_factor = {self.quant_config.pack_factor}.") - - # Validate input_size_per_partition - if input_size_per_partition % self.quant_config.min_k_threads != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"min_k_threads = {self.quant_config.min_k_threads}.") - if (self.quant_config.group_size != -1 and - input_size_per_partition % self.quant_config.group_size != 0): - raise ValueError(f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"group_size = {self.quant_config.group_size}.") - - # Check that we have at least 4 tiles horizontally in the shard - num_tiles_per_perm = self.quant_config.perm_len // ( - self.quant_config.tile_size**2) - if output_size_per_partition % num_tiles_per_perm != 0: - raise ValueError( - "Each permutation group must reside on the same gpu") - - # Quantized 4Bit weights packed into Int32. 
- qweight = PackedvLLMParameter( - data=torch.empty( - input_size_per_partition // self.quant_config.tile_size, - output_size_per_partition * self.quant_config.tile_size // - self.quant_config.pack_factor, - device="cuda", - dtype=torch.int32, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=self.quant_config.pack_factor, - marlin_tile_size=self.quant_config.tile_size, - weight_loader=weight_loader) - - s_channel = ChannelQuantScaleParameter(data=torch.empty( - 1, - output_size_per_partition, - device="cuda", - dtype=torch.float, - ), - weight_loader=weight_loader, - output_dim=1) - - if self.quant_config.group_size == -1: - s_group_data = torch.tensor( - [], - device="cuda", - dtype=torch.half, - ) - else: - s_group_data = torch.empty( - input_size_per_partition // self.quant_config.group_size, - output_size_per_partition, - device="cuda", - dtype=torch.half, - ) - - s_group_attr = {"data": s_group_data, "weight_loader": weight_loader} - - if self.quant_config.group_size == -1: - s_group = BasevLLMParameter(**s_group_attr) - else: - s_group = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **s_group_attr) - - # Allocate workspace (Used for internal locking mechanism) - max_workspace_size = ( - output_size_per_partition // - self.quant_config.min_n_threads) * self.quant_config.max_parallel - - workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, - device="cuda", - dtype=torch.int), - weight_loader=weight_loader) - - layer.register_parameter("B", qweight) - layer.register_parameter("s_channel", s_channel) - layer.register_parameter("s_group", s_group) - layer.register_parameter("workspace", workspace) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # required by torch.compile - layer.B = Parameter(layer.B.data, requires_grad=False) - layer.s_channel = Parameter(layer.s_channel.data, requires_grad=False) - layer.s_group = Parameter(layer.s_group.data, requires_grad=False) - layer.workspace = 
Parameter(layer.workspace.data, requires_grad=False) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qweight = layer.B - s_ch = layer.s_channel - s_group = layer.s_group - workspace = layer.workspace - - x_2d = x.view(-1, x.shape[-1]) - - size_m = x_2d.shape[0] - size_k = x_2d.shape[1] - size_n = s_ch.shape[1] - - x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d) - - output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group, - workspace, size_m, size_n, size_k) - - output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) - - if bias is not None: - output.add_(bias) # In-place add - - return output diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 6f69210d0861c..58f56c6381b31 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -7,7 +7,8 @@ import torch from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, + FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( OCP_MX_BLOCK_SIZE) @@ -25,6 +26,9 @@ __all__ = [ class QuarkMoEMethod(FusedMoEMethodBase): + def __init__(self, moe: FusedMoEConfig): + super().__init__(moe) + @staticmethod def get_moe_method( quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821 @@ -42,17 +46,24 @@ class QuarkMoEMethod(FusedMoEMethodBase): input_config = layer_quant_config.get("input_tensors") if quant_config._is_fp8_w8a8(weight_config, input_config): - return QuarkW8A8Fp8MoEMethod(weight_config, input_config) + return QuarkW8A8Fp8MoEMethod(weight_config, input_config, + module.moe_config) elif 
quant_config._is_mx_fp4(weight_config, input_config): - return QuarkW4A4MXFp4MoEMethod(weight_config, input_config) + return QuarkW4A4MXFp4MoEMethod(weight_config, input_config, + module.moe_config) else: raise RuntimeError("Unsupported FusedMoe scheme") class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): - def __init__(self, weight_config: dict[str, Any], input_config: dict[str, - Any]): + def __init__( + self, + weight_config: dict[str, Any], + input_config: dict[str, Any], + moe: FusedMoEConfig, + ): + super().__init__(moe) self.weight_quant = weight_config self.input_quant = input_config @@ -215,6 +226,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.") @@ -231,7 +244,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return fused_experts( x, @@ -253,8 +267,13 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): - def __init__(self, weight_config: dict[str, Any], input_config: dict[str, - Any]): + def __init__( + self, + weight_config: dict[str, Any], + input_config: dict[str, Any], + moe: FusedMoEConfig, + ): + super().__init__(moe) self.weight_quant = weight_config self.input_quant = input_config @@ -369,6 +388,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None if enable_eplb: raise NotImplementedError( @@ -386,7 +406,8 @@ class 
QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) out = fused_experts( x, diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index cceaf9857c40f..8bdb50e07b137 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -10,7 +10,8 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, + FusedMoEMethodBase) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -76,7 +77,7 @@ class RTNConfig(QuantizationConfig): if isinstance(layer, LinearBase): return RTNLinearMethod(self) elif isinstance(layer, FusedMoE): - return RTNMoEMethod(self) + return RTNMoEMethod(self, layer.moe_config) return None @@ -210,7 +211,8 @@ class RTNLinearMethod(LinearMethodBase): class RTNMoEMethod(FusedMoEMethodBase): - def __init__(self, quant_config: RTNConfig): + def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig): + super().__init__(moe) self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -289,6 +291,8 @@ class RTNMoEMethod(FusedMoEMethodBase): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `RTNMoEMethod` yet.") @@ -305,7 +309,8 @@ class 
RTNMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) weight_bits = self.quant_config.weight_bits group_size = self.quant_config.group_size diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..fbca5ce05d018 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/README.md b/vllm/model_executor/layers/quantization/utils/configs/README.md new file mode 100644 index 0000000000000..1110ced4fa063 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/README.md @@ -0,0 +1,3 @@ +# Quantization Kernel Config + +Use scripts under `benchmarks/kernels/` to generate these config files. 
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 8ef91eeed406f..f5d7c57fe2a87 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -3,33 +3,30 @@ """Utility helpers for NVFP4 + FlashInfer fused-MoE path""" from __future__ import annotations -from typing import Optional - import torch import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe) + FlashInferExperts) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 FlashInferCutlassMoEPrepareAndFinalize) from vllm.platforms import current_platform - -logger = init_logger(__name__) +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe __all__ = [ "is_flashinfer_fp4_cutlass_moe_available", "reorder_w1w3_to_w3w1", - "build_flashinfer_fp4_cutlass_moe_kernel", - "flashinfer_fp4_cutlass_moe_forward", + "build_flashinfer_fp4_cutlass_moe_prepare_finalize", ] def is_flashinfer_fp4_cutlass_moe_available() -> bool: """Return ``True`` when FlashInfer CUTLASS NV-FP4 kernels can be used.""" - return (envs.VLLM_USE_FLASHINFER_MOE_FP4 and current_platform.is_cuda() + return (envs.VLLM_USE_FLASHINFER_MOE_FP4 + and has_flashinfer_cutlass_fused_moe() + and current_platform.is_cuda() and current_platform.is_device_capability(100)) @@ -49,105 +46,33 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor, dim=dim).contiguous()) -def build_flashinfer_fp4_cutlass_moe_kernel( - moe_parallel_config: FusedMoEParallelConfig, ) -> 
mk.FusedMoEModularKernel: - """Create *and return* a FlashInfer CUTLASS fused-MoE modular kernel""" - experts = FlashInferExperts( - use_nvfp4_w4a4=True, - use_dp=moe_parallel_config.dp_size > 1, - ep_rank=moe_parallel_config.ep_rank, - ep_size=moe_parallel_config.ep_size, - tp_rank=moe_parallel_config.tp_rank, - tp_size=moe_parallel_config.tp_size, - ) - logger.debug_once("FlashInferExperts (util)") - return mk.FusedMoEModularKernel( - FlashInferCutlassMoEPrepareAndFinalize(quant_dtype=torch.uint8), - experts, - ) - - -def flashinfer_fp4_cutlass_moe_forward( - fused_experts: mk.FusedMoEModularKernel, - layer: torch.nn.Module, - x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, -) -> torch.Tensor: - """Common forward wrapper for FlashInfer NV-FP4 fused-MoE""" - - assert is_valid_flashinfer_cutlass_fused_moe( - x, layer.w13_weight, - layer.w2_weight), ("FlashInfer CUTLASS fused-MoE not applicable!") - - a1_gscale = layer.w13_input_scale_quant - a2_gscale = layer.w2_input_scale_quant - - extra_expert_args = { - "g1_alphas": layer.g1_alphas, - "g2_alphas": layer.g2_alphas, - # Avoid confusion with a1_scale and a2_scale - # where are batch size related. 
- "a1_gscale": a1_gscale, - "a2_gscale": a2_gscale, - "out_dtype": x.dtype, - } - extra_prepare_args = { - "use_dp": layer.dp_size > 1, - "local_tokens": x.shape[0], - "a1_gscale": a1_gscale, - } - extra_finalize_args = { - "use_dp": layer.dp_size > 1, - "local_tokens": x.shape[0], - } - - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=False, # TODO(shuw): fix later, now output is high prec - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=layer.w13_blockscale_swizzled, - w2_scale=layer.w2_blockscale_swizzled, - apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=extra_expert_args, - extra_prepare_args=extra_prepare_args, - extra_finalize_args=extra_finalize_args, - ) +def build_flashinfer_fp4_cutlass_moe_prepare_finalize( + moe: FusedMoEConfig, + a1_gscale: torch.Tensor, +) -> mk.FusedMoEPrepareAndFinalize: + """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" + use_dp = moe.moe_parallel_config.dp_size > 1 + return FlashInferCutlassMoEPrepareAndFinalize(use_dp, a1_gscale=a1_gscale) def select_nvfp4_gemm_impl( - allow_flashinfer: bool, - moe, # FusedMoEConfig - logger): + moe: FusedMoEConfig, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, + allow_flashinfer: bool, +) -> mk.FusedMoEPermuteExpertsUnpermute: """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers""" - # lazy import - from vllm.distributed import get_ep_group - - all2all_manager = get_ep_group().device_communicator.all2all_manager - assert all2all_manager is not None - if allow_flashinfer: - flashinfer_backend = envs.VLLM_FLASHINFER_MOE_BACKEND - if flashinfer_backend != "throughput": - raise ValueError( - f"Only throughput backend is supported for FlashInferExperts, " - f"but got {flashinfer_backend}.") - logger.debug_once( - 
"Initializing FlashInferExperts with throughput backend.") return FlashInferExperts( - use_nvfp4_w4a4=True, - use_dp=moe.moe_parallel_config.dp_size > 1, + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, + a1_gscale=a1_gscale, + a2_gscale=a2_gscale, + out_dtype=moe.in_dtype, + quant_dtype="nvfp4", ep_rank=moe.moe_parallel_config.ep_rank, ep_size=moe.moe_parallel_config.ep_size, tp_rank=moe.moe_parallel_config.tp_rank, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 278ee5232f47e..9889808f0760f 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -1,9 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import Enum from typing import Optional import torch +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import envs +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts) +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 + FlashInferCutlassMoEPrepareAndFinalize) + +logger = init_logger(__name__) + + +class FlashinferMoeBackend(Enum): + TENSORRT_LLM = "TensorRT-LLM" + CUTLASS = "CUTLASS" + def calculate_tile_tokens_dim(num_tokens, top_k, num_experts): @@ -144,3 +161,98 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None: layer.register_parameter( 'output2_scales_scalar', torch.nn.Parameter(output2_scales, requires_grad=False)) + layer.register_parameter( + 'w2_input_scale_inv', + torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False)) + + +def build_flashinfer_fp8_cutlass_moe_prepare_finalize( + moe: Optional[FusedMoEConfig], + layer: 
torch.nn.Module, +) -> mk.FusedMoEPrepareAndFinalize: + """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" + use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False + return FlashInferCutlassMoEPrepareAndFinalize( + use_dp, a1_gscale=layer.w13_input_scale) + + +def select_cutlass_fp8_gemm_impl( + moe: Optional[FusedMoEConfig], + layer: torch.nn.Module, + out_dtype: Optional[torch.dtype] = None, +) -> mk.FusedMoEPermuteExpertsUnpermute: + """Return a GEMM *experts* implementation for fused-MoE layers""" + + from vllm.model_executor.models.llama4 import Llama4MoE + assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \ + "FusedMoE flashinfer kernels are only supported for Llama4" + + if moe is not None: + return FlashInferExperts( + g1_alphas=layer.output1_scales_gate_scalar, + g2_alphas=layer.output2_scales_scalar, + a1_gscale=layer.w13_input_scale, + a2_gscale=layer.w2_input_scale_inv, + out_dtype=moe.in_dtype, + quant_dtype=torch.float8_e4m3fn, + ep_rank=moe.moe_parallel_config.ep_rank, + ep_size=moe.moe_parallel_config.ep_size, + tp_rank=moe.moe_parallel_config.tp_rank, + tp_size=moe.moe_parallel_config.tp_size, + ) + + assert out_dtype is not None, ( + "If moe config is None, out_dtype must be passed") + return FlashInferExperts( + g1_alphas=layer.output1_scales_gate_scalar, + g2_alphas=layer.output2_scales_scalar, + a1_gscale=layer.w13_input_scale, + a2_gscale=layer.w2_input_scale_inv, + out_dtype=out_dtype, + quant_dtype=torch.float8_e4m3fn, + ) + + +def flashinfer_cutlass_moe_fp8( + hidden_states: torch.Tensor, + layer: torch.nn.Module, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, +) -> torch.Tensor: + fused_experts = mk.FusedMoEModularKernel( + build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None, + 
layer=layer), + select_cutlass_fp8_gemm_impl(moe=None, + layer=layer, + out_dtype=hidden_states.dtype)) + + return fused_experts( + hidden_states, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=inplace, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + +def get_flashinfer_moe_backend() -> FlashinferMoeBackend: + flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND + if flashinfer_moe_backend == "throughput": + return FlashinferMoeBackend.CUTLASS + elif flashinfer_moe_backend == "latency": + return FlashinferMoeBackend.TENSORRT_LLM + + allowed_backends = ["throughput", "latency"] + raise ValueError( + f"Unknown flashinfer moe backend: {flashinfer_moe_backend}" + f" expected one of {allowed_backends}") diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 2fb7ef29e4684..7b324dce3c367 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -19,8 +19,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils import cdiv, direct_register_custom_op +from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used, + should_use_deepgemm_for_fp8_linear) logger = init_logger(__name__) @@ -108,19 +109,6 @@ def dispatch_w8a8_blockscale_func( return w8a8_block_fp8_matmul -def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor): - """ - Check if DeepGEMM should be used based on the output dtype and weight shape. 
- DeepGEMM is only supported for bfloat16 output dtype and weights with shape - divisible by 128. - """ - - return (current_platform.is_cuda() - and current_platform.is_device_capability(90) and has_deep_gemm() - and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16 - and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) - - # TODO fix ROCm->Triton custom path: # https://github.com/vllm-project/vllm/issues/14397 def apply_w8a8_block_fp8_linear( @@ -139,7 +127,7 @@ def apply_w8a8_block_fp8_linear( output_shape = [*input.shape[:-1], weight.shape[0]] output_dtype = input.dtype - if should_use_deepgemm(output_dtype, weight): + if should_use_deepgemm_for_fp8_linear(output_dtype, weight): input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] @@ -150,7 +138,9 @@ def apply_w8a8_block_fp8_linear( column_major_scales=True, ) + # ensure DeepGEMM-backed custom op is registered before use import vllm.model_executor.layers.quantization.deepgemm # noqa: F401 + output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( q_input, weight, @@ -395,7 +385,7 @@ def per_token_group_quant_fp8( scaling factor. 
""" if use_ue8m0 is None: - use_ue8m0 = is_blackwell_deep_gemm_e8m0_used() + use_ue8m0 = is_deep_gemm_e8m0_used() dtype = current_platform.fp8_dtype() if dtype is None else dtype assert (x.shape[-1] % group_size == 0), ( f"the last dimension of `x` {x.shape[-1]} must be divisible " diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py deleted file mode 100644 index 8a64bebae04c9..0000000000000 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py +++ /dev/null @@ -1,126 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import numpy -import torch - -from .marlin_utils_test import marlin_permute_weights -from .quant_utils import get_pack_factor, qqq_quantize_weights - - -def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), - dtype=numpy.uint32) - if group_size == size_k: - for i in range(pack_factor): - q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i - else: - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) - - return q_packed - - -def get_qqq_scale_perms(): - scale_perm: list[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: list[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 -def get_qqq_weight_perm(num_bits: int, quant_type: str): - perm_list: list[int] = [] - for i in range(32): - perm1: list[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 4 * (i % 4), - 4 * (i % 4) + 1, - 4 * (i % 4) + 2, - 4 * (i % 4) + 3, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm_list) - - assert quant_type in ["per-channel", - "per-group"], "not supported quantization type" - if num_bits == 4: - if quant_type == "per-channel": - interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) - else: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - else: - raise Exception("num_bits must be 4, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): - scale_perm, scale_perm_single = get_qqq_scale_perms() - if group_size < size_k and group_size != -1: - s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_group = s_group.reshape((-1, size_n)).contiguous() - else: - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_channel = s_channel.reshape((-1, size_n)).contiguous() - - return s_group, s_channel - - -def marlin_qqq_quantize( - w: torch.Tensor, - num_bits: int, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - quant_type = "per-channel" if group_size == size_k else "per-group" - - # Quantize - w_ref, q_w, s_group, s_channel = qqq_quantize_weights( - w, num_bits, group_size) - - # Reformat to marlin_qqq - weight_perm = get_qqq_weight_perm(num_bits, quant_type) - marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, - 
weight_perm, group_size) - marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( - s_group, s_channel, size_k, size_n, group_size) - - # Create result - res_list = [ - w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel - ] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index deeb69bcad0ec..48f9cc3737e47 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -61,14 +61,14 @@ def _can_support_mxfp4(use_grouped_topk: bool = False, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, scoring_func: str = "softmax", - activation: str = "swiglu_oai", + activation: str = "swigluoai", expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None): return not (use_grouped_topk or topk_group or num_expert_group or expert_map or custom_routing_function or e_score_correction_bias or apply_router_weight_on_input - or scoring_func != "softmax" or activation != "swiglu_oai" + or scoring_func != "softmax" or activation != "swigluoai" or expert_load_view or logical_to_physical_map or logical_replica_count) diff --git a/vllm/model_executor/layers/quantization/utils/petit_utils.py b/vllm/model_executor/layers/quantization/utils/petit_utils.py new file mode 100644 index 0000000000000..00d3def1db81e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/petit_utils.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING, Optional + +import torch + +# TYPE_CHECKING is used for static type analysis to prevent circular imports. 
+if TYPE_CHECKING: + from types import ModuleType + +# 1. Create a global variable as a placeholder for the module +_petit_kernel: Optional["ModuleType"] = None + +_PETIT_INSTALL_MSG = ("Petit is not installed. Please install it with " + "`pip install petit-kernel`.") + + +def _import_petit_kernel() -> "ModuleType": + """ + A helper function to handle the lazy import. + The first time this function is called, it will import the petit_kernel + library and store it in the global _petit_kernel variable. + Subsequent calls will return the already-loaded module directly. + """ + global _petit_kernel + if _petit_kernel is not None: + return _petit_kernel + + try: + import petit_kernel + _petit_kernel = petit_kernel + return _petit_kernel + except ImportError: + # The 'from None' syntax prevents chaining the original ImportError, + # making the traceback cleaner. + raise ImportError(_PETIT_INSTALL_MSG) from None + + +# The _require_petit function can now be a simple alias for consistency. +_require_petit = _import_petit_kernel + + +def _check_petit_nvfp4_supported( + quant_method: str, + group_size: Optional[int]) -> tuple[bool, Optional[str]]: + if quant_method != "NVFP4": + return ( + False, + ("Petit currently only supports: NVFP4 quantizations in sglang. " + "Please check the `hf_quant_config.json` file for your model's " + "quant configuration."), + ) + if group_size is not None and group_size != 16: + return ( + False, + "Petit currently only supports: group_size=16 quantizations.", + ) + return (True, None) + + +def verify_petit_nvfp4_supported(quant_method: str, + group_size: Optional[int]) -> None: + supported, error_msg = _check_petit_nvfp4_supported( + quant_method, group_size) + if not supported: + assert error_msg is not None + raise ValueError(error_msg) + + +def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None: + # 2. Call _import_petit_kernel() to trigger (or get) the import. 
+ petit_kernel = _import_petit_kernel() + + # Repack weights to petit format + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + qweight = layer.weight.view(torch.int32).contiguous() + + # 3. Call functions through the imported module variable. + petit_qweight = petit_kernel.repack_nvfp4(qweight, + size_n=part_size_n, + size_k=part_size_k) + layer.weight = torch.nn.Parameter(petit_qweight, requires_grad=False) + + # Permute scales + weight_scale = petit_kernel.process_nvfp4_scales(scales=layer.weight_scale, + size_k=part_size_k, + size_n=part_size_n) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + +def apply_petit_nvfp4_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + # Trigger (or get) the import here as well. + petit_kernel = _import_petit_kernel() + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + # TODO: Use auto-tuning to find the performant solution_id + # Call the function via the module variable. 
+ output = petit_kernel.mul_nvfp4_a16( + a=reshaped_x, + b=weight, + s=weight_scale, + global_scale=weight_scale_2, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + solution_id=-1, + ) + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 428e9e99aa881..6154fca2e416d 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -2,18 +2,21 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """This file is used for /tests and /benchmarks""" from collections.abc import Mapping +from dataclasses import dataclass from types import MappingProxyType from typing import ClassVar, NamedTuple, Optional import numpy import torch +from torch import fx from vllm._custom_ops import cutlass_scaled_mm_supports_fp4 -from vllm.model_executor.layers.quantization.qqq import ( - MARLIN_QQQ_SUPPORTED_NUM_BITS) from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 + # Use proxy as NamedTuple direct subclasses cannot have static members class _GroupShape(NamedTuple): @@ -36,6 +39,64 @@ GroupShape.PER_TENSOR = GroupShape(-1, -1) GroupShape.PER_TOKEN = GroupShape(1, -1) +@dataclass(frozen=True) +class ScaleDesc: + """ + Class for describing a single quantization scaling factor. 
+ dtype: data type of the scale + static: static scale if True, dynamic if False + group_shape: group shape of the scale + """ + dtype: torch.dtype + static: bool + group_shape: GroupShape + + def __str__(self): + group_shape = ('per_tensor' + if self.group_shape == GroupShape.PER_TENSOR else + ('per_token' if self.group_shape == GroupShape.PER_TOKEN + else str(self.group_shape))) + + return (f"{fx.graph.dtype_abbrs[self.dtype]}," + f"{'static' if self.static else 'dynamic'},{group_shape}") + + +@dataclass(frozen=True) +class QuantKey: + """ + Class for identifying the type of quantization. + dtype: quantized data type + scale: scale descriptor + scale2: second-level scale descriptor + symmetric: symmetric if True, asymmetric if False + """ + dtype: torch.dtype + scale: ScaleDesc + scale2: Optional[ScaleDesc] = None + symmetric: bool = True + + def __str__(self): + scale2_str = f"scale2({self.scale2})," if self.scale2 else "" + return (f"QuantKey({fx.graph.dtype_abbrs[self.dtype]}," + f"scale({self.scale}),{scale2_str}" + f"{'a' if not self.symmetric else ''}symmetric)") + + +kStaticTensorScale = ScaleDesc(torch.float32, True, GroupShape.PER_TENSOR) +kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True) + +kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR) +kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True) + +kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) +kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) + +kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) +kNvfp4Quant = QuantKey(FP4_DTYPE, + scale=kNvfp4GroupScale, + scale2=kStaticTensorScale) + + # Normalize the group_shape to the full extent for any dims that are -1 def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): # -1 means full extent @@ -386,89 +447,6 @@ def gptq_quantize_weights(w: torch.Tensor, return w_ref, w_q, w_s, g_idx, 
rand_perm -# QQQ employs different quant schemes for per-group and -# per-channel quantization. -def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): - orig_device = w.device - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \ - f"Unsupported num_bits = {num_bits}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - if group_size < size_k: - # Reshape to [groupsize, -1] - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - max_q_val = 2**num_bits - 1 - half_q_val = (max_q_val + 1) // 2 - - # Compute scale for each group - s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_group *= 2 / max_q_val # 2 => symmetric - - # Quantize - q_w = torch.round(w / s_group).int() - q_w += half_q_val - q_w = torch.clamp(q_w, 0, max_q_val) - # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s_group - - # Restore original shapes - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) - - # Compute int8 quantization scale for each channel - s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] - s_channel /= 127.0 - t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) - w_ref = t_int8.half() * s_channel - s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) - - # Fuse scales - s_group = (s_group.reshape(-1, size_n).contiguous() / - s_channel).to(dtype=torch.half) - else: - max_q_val = 2**(num_bits - 1) - 1 - - # Compute scale for each channel - s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_channel /= max_q_val - - # Quantize - q_w = torch.round(w / s_channel).int() - q_w = torch.clamp(q_w, -max_q_val, 
max_q_val) - # Compute ref (dequantized) - w_ref = q_w.half() * s_channel - - s_group = torch.tensor([], dtype=torch.half) - # div 2 ** (8 - self.bits)) to offset right shift in unpacking - s_channel /= (2**(8 - num_bits)) - s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s_group.to(device=orig_device), - s_channel.to(device=orig_device), - ) - - def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): orig_device = q_w.device @@ -637,8 +615,8 @@ def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor: swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda() if scale_ndim == 2: - return swizzled.reshape(M, K) - return swizzled.reshape(B, M, K) + return swizzled.reshape(M_padded, K_padded) + return swizzled.reshape(B, M_padded, K_padded) def cutlass_fp4_supported() -> bool: diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index ddb50968904d1..5333bbd310ff9 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -13,6 +13,8 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op +from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale @@ -121,6 +123,9 @@ def requantize_with_max_scale( if unfused_module_in_checkpoint: start = 0 for idx, logical_width in enumerate(logical_widths): + # Skip any component with zero width. 
+ if logical_width == 0: + continue end = start + logical_width weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx]) @@ -153,13 +158,23 @@ def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, return output.view(*output_shape) -def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, - output_shape: list) -> torch.Tensor: +def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, + out_dtype: torch.dtype, scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + output_shape: list, **kwargs) -> torch.Tensor: + + return flashinfer_scaled_fp8_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + + +def rocm_per_tensor_w8a8_scaled_mm_impl( + qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, + scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor) -> torch.Tensor: from vllm.platforms.rocm import on_mi3xx if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx( ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: @@ -172,10 +187,38 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, scale_a=scale_a, scale_b=scale_b, bias=bias) + return output + +def rocm_per_tensor_w8a8_scaled_mm_fake( + qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, + scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor) -> torch.Tensor: + return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), + dtype=out_dtype) + + +def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: list) -> torch.Tensor: + output = 
torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( + qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d) return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) +direct_register_custom_op( + op_name="rocm_per_tensor_w8a8_scaled_mm_impl", + op_func=rocm_per_tensor_w8a8_scaled_mm_impl, + mutates_args=[], + fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake, + dispatch_key=current_platform.dispatch_key, +) + + def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, @@ -202,8 +245,8 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, - output_shape: list) -> torch.Tensor: + input_2d: torch.Tensor, output_shape: list, + **kwargs) -> torch.Tensor: # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM # when using it. # For now it has only been validated on ROCm platform. @@ -274,16 +317,22 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, def dispatch_w8a8_scaled_mm( - cutlass_fp8_supported: bool, per_tensor_weights: bool, + preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool) -> Callable[..., torch.Tensor]: - # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A - if cutlass_fp8_supported: - return cutlass_w8a8_scaled_mm if per_tensor_weights and per_tensor_activations: - if current_platform.is_rocm(): + if preferred_backend == "rocm": return rocm_per_tensor_w8a8_scaled_mm + if preferred_backend == "flashinfer": + return flashinfer_w8a8_scaled_mm + if preferred_backend == "cutlass": + return cutlass_w8a8_scaled_mm return torch_per_tensor_w8a8_scaled_mm + + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A + if preferred_backend == "cutlass" or preferred_backend == "flashinfer": + return cutlass_w8a8_scaled_mm + # If torch.scaled_mm supports per-channel (weights) per-token 
(inputs) if not per_tensor_weights and not per_tensor_activations \ and USE_ROWWISE_TORCH_SCALED_MM: @@ -305,10 +354,20 @@ class Fp8LinearOp: def __init__(self, act_quant_static: bool, - cutlass_fp8_supported: bool = cutlass_fp8_supported(), act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, - pad_output: Optional[bool] = None): - self.cutlass_fp8_supported = cutlass_fp8_supported + pad_output: Optional[bool] = None, + force_fp8_e4m3fnuz: bool = False): + if current_platform.is_rocm(): + self.preferred_backend = "rocm" + elif current_platform.is_cuda( + ) and not force_fp8_e4m3fnuz and cutlass_fp8_supported(): + if has_flashinfer() and current_platform.has_device_capability( + 100): + self.preferred_backend = "flashinfer" + else: + self.preferred_backend = "cutlass" + else: + self.preferred_backend = "torch" # Note: we pad the input because torch._scaled_mm is more performant # for matrices with batch dimension > 16. @@ -318,8 +377,7 @@ class Fp8LinearOp: if pad_output is None: config = get_current_vllm_config().compilation_config pad_output = config.level < CompilationLevel.PIECEWISE and \ - not cutlass_fp8_supported and \ - not current_platform.is_rocm() + self.preferred_backend == "torch" self.output_padding = 17 if pad_output else None self.act_quant_static = act_quant_static @@ -364,9 +422,9 @@ class Fp8LinearOp: per_tensor_activations = (x_scale.numel() == 1) # TODO(luka) do this dispatch during init (after ScaledMM refactor) - w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( - self.cutlass_fp8_supported, per_tensor_weights, - per_tensor_activations) + w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(self.preferred_backend, + per_tensor_weights, + per_tensor_activations) return w8a8_scaled_mm_func(qinput=qinput, weight=weight, diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py new file mode 100644 index 0000000000000..05322e56f2620 --- /dev/null +++ 
b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from .common import apply_rotary_emb_dispatch +from .mrope import MRotaryEmbedding + + +class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding): + """3D rotary positional embedding. 3D is t:time h:height w:width""" + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert positions.ndim == 1 or positions.ndim == 2 + assert key is not None + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if positions.ndim == 2: + assert self.mrope_section + + section_h = self.mrope_section[0] # 22 + section_w = self.mrope_section[1] # 22 + section_t = self.mrope_section[2] # 20 + assert section_h == section_w + # Split according to [h w h w h w h w... t t t...] 
+ section_cos_t = cos[..., -section_t:] + section_cos_h = cos[..., :section_h + section_w:2] + section_cos_w = cos[..., 1:section_h + section_w:2] + + cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[ + 1], section_cos_w[2] + cos_hw = torch.stack([cos_h, cos_w], + dim=-1).reshape(cos_h.shape[:-1] + + (cos_h.shape[-1] * 2, )) + cos = torch.cat([cos_hw, cos_t], dim=-1) + + section_sin_t = sin[..., -section_t:] + section_sin_h = sin[..., :section_h + section_w:2] + section_sin_w = sin[..., 1:section_h + section_w:2] + + sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[ + 1], section_sin_w[2] + sin_hw = torch.stack([sin_h, sin_w], + dim=-1).reshape(sin_h.shape[:-1] + + (sin_h.shape[-1] * 2, )) + sin = torch.cat([sin_hw, sin_t], dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, + self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, + self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index a091cfb743291..e374aa9bebf9e 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -393,6 +393,15 @@ class MRotaryEmbedding(RotaryEmbedding): context_len=context_len, seq_len=seq_len, ) + elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]: + return cls._ernie_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + 
video_grid_thw=video_grid_thw, + context_len=context_len, + seq_len=seq_len, + ) else: return cls._vl_get_input_positions_tensor( input_tokens=input_tokens, @@ -513,6 +522,120 @@ class MRotaryEmbedding(RotaryEmbedding): len(input_tokens)).item() return llm_positions, mrope_position_delta + @classmethod + def _ernie_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value for Ernie VL.""" + + image_token_id = hf_config.im_patch_id + video_start_token_id = hf_config.video_start_token_id + video_end_token_id = hf_config.video_end_token_id + spatial_conv_size = hf_config.spatial_conv_size + temporal_conv_size = hf_config.temporal_conv_size + llm_pos_ids_list: list = [] + + if not (image_grid_thw is None and video_grid_thw is None): + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + + input_token_type: list[str] = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if (token == image_token_id) and (video_check_flg is False): + input_token_type.append("image") + elif (token == image_token_id) and (video_check_flg is True): + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group: list[tuple[str, int, int]] = [] + for key, group_iter in itertools.groupby( + enumerate(input_token_type), lambda x: x[1]): + group_list = list(group_iter) + start_index = group_list[0][0] + end_index = group_list[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + video_frame_num = 1 + mm_data_idx = 0 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = 
llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + if modality_type == "image": + t, h, w = ( + image_grid_thw[mm_data_idx][0], + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_conv_size, w // spatial_conv_size + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx) + mm_data_idx += 1 + + elif modality_type == "video": + t, h, w = ( + video_grid_thw[mm_data_idx][0], + video_grid_thw[mm_data_idx][1], + video_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = (t // + temporal_conv_size, + h // + spatial_conv_size, + w // + spatial_conv_size) + + for t_idx in range(llm_grid_t): + t_index = torch.tensor(t_idx).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view( + 1, -1, 1).expand(1, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view( + 1, 1, -1).expand(1, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx) + + mm_data_idx += 1 + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + + st_idx) + video_frame_num = 1 + + else: + text_len = len(input_tokens) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1)) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:seq_len] + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + return llm_positions, mrope_position_delta + @classmethod def 
_vl_get_input_positions_tensor( cls, diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 48a347a8f5611..2897f75b3129e 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -142,6 +142,12 @@ direct_register_custom_op( ) +def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype): + return (torch._C._cpu._is_amx_tile_supported() + and (dtype in (torch.bfloat16, torch.int8)) and k % 32 == 0 + and n % 16 == 0) + + def cpu_unquantized_gemm(layer: torch.nn.Module, x: torch.Tensor, weight: torch.Tensor, diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 2b8e4427591c1..34b8d8e4ed622 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -207,16 +207,21 @@ class DefaultModelLoader(BaseModelLoader): ) if current_platform.is_tpu(): - # In PyTorch XLA, we should call `xm.mark_step` frequently so that - # not too many ops are accumulated in the XLA program. - import torch_xla.core.xla_model as xm + from vllm.platforms.tpu import USE_TPU_COMMONS - def _xla_weights_iterator(iterator: Generator): - for weights in iterator: - yield weights - xm.mark_step() + if not USE_TPU_COMMONS: + # In PyTorch XLA, we should call `xm.mark_step` + # requently so that not too many ops are accumulated + # in the XLA program. 
import torch_xla.core.xla_model + # as xm + import torch_xla.core.xla_model as xm - weights_iterator = _xla_weights_iterator(weights_iterator) + def _xla_weights_iterator(iterator: Generator): + for weights in iterator: + yield weights + xm.mark_step() + + weights_iterator = _xla_weights_iterator(weights_iterator) if self.counter_before_loading_weights == 0.0: self.counter_before_loading_weights = time.perf_counter() diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 21655b0c69bb4..9877cb3b7c06e 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -14,7 +14,8 @@ from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, set_default_torch_dtype) from vllm.model_executor.model_loader.weight_utils import ( - get_gguf_extra_tensor_names, gguf_quant_weights_iterator) + get_gguf_extra_tensor_names, get_gguf_weight_type_map, + gguf_quant_weights_iterator) class GGUFModelLoader(BaseModelLoader): @@ -132,6 +133,17 @@ class GGUFModelLoader(BaseModelLoader): local_model_path, gguf_weights_map): model_config.hf_config.update({"tie_word_embeddings": True}) + weight_type_map = get_gguf_weight_type_map(model_config.model, + gguf_weights_map) + + # filter out unquantized modules to skip + unquant_names = [ + name.removesuffix(".weight") + for name, weight_type in weight_type_map.items() + if weight_type == "F32" and name.endswith(".weight") + ] + vllm_config.quant_config.unquantized_modules.extend(unquant_names) + target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): with target_device: diff --git a/vllm/model_executor/model_loader/tpu.py b/vllm/model_executor/model_loader/tpu.py index b44c165397d02..a70cdeb483e67 100644 --- a/vllm/model_executor/model_loader/tpu.py +++ 
b/vllm/model_executor/model_loader/tpu.py @@ -98,14 +98,15 @@ class TPUModelLoader(DefaultModelLoader): # Check parameters for name, param in model.named_parameters(): - assert param.device.type == device_type, f"Parameter {name} is on \ - {param.device.type} instead of {device_type}" + assert param.device.type == device_type, ( + f"Parameter {name} is on {param.device.type} " + f"instead of {device_type}") # Check buffers for name, buffer in model.named_buffers(): - assert buffer.device.type == device_type, \ - f"Buffer {name} is on {buffer.device.type} instead of \ - {device_type}" + assert buffer.device.type == device_type, ( + f"Buffer {name} is on {buffer.device.type} " + f"instead of {device_type}") for module in model.modules(): if (mesh is not None) and (get_fqn(module) == 'QKVParallelLinear'): diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 78b186265dd04..3bb47f82d2f37 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -31,9 +31,7 @@ from vllm.utils import PlaceholderModule try: from runai_model_streamer import SafetensorsStreamer -except (ImportError, OSError): - # see https://github.com/run-ai/runai-model-streamer/issues/26 - # OSError will be raised on arm64 platform +except ImportError: runai_model_streamer = PlaceholderModule( "runai_model_streamer") # type: ignore[assignment] SafetensorsStreamer = runai_model_streamer.placeholder_attr( @@ -565,6 +563,18 @@ def get_gguf_extra_tensor_names( return [gguf_to_hf_name_map[key] for key in extra_keys] +def get_gguf_weight_type_map( + gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> dict[str, str]: + """ + Return GGUF mapped weight's name and its quant type + """ + reader = gguf.GGUFReader(gguf_file) + return { + gguf_to_hf_name_map[tensor.name]: tensor.tensor_type.name + for tensor in reader.tensors if tensor.name in gguf_to_hf_name_map + } + + def 
gguf_quant_weights_iterator( gguf_file: str, gguf_to_hf_name_map: dict[str, str] ) -> Generator[tuple[str, torch.Tensor], None, None]: diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 1dbe70f84a626..49e9a2d65ea11 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -7,15 +7,21 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast import torch import torch.nn as nn +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.models.config import VerifyAndUpdateConfig +from vllm.transformers_utils.config import (get_hf_file_bytes, + get_hf_file_to_dict) from .interfaces_base import VllmModelForPooling, is_pooling_model if TYPE_CHECKING: - from vllm.config import VllmConfig + from vllm.config import ModelConfig, VllmConfig _T = TypeVar("_T", bound=type[nn.Module]) +logger = init_logger(__name__) + _GENERATE_SUFFIXES = [ "ForCausalLM", "ForConditionalGeneration", @@ -24,6 +30,96 @@ _GENERATE_SUFFIXES = [ ] +def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]: + """Load Sentence-Transformers Dense projection layers.""" + + try: + modules = get_hf_file_to_dict("modules.json", model_config.model, + model_config.revision) + if not modules: + return None + + if isinstance(modules, dict): + modules = modules.get("modules", []) + + dense_modules = [ + m for m in modules + if m.get("type") == "sentence_transformers.models.Dense" + ] + if not dense_modules: + return None + + module = dense_modules[0] + folder = module.get("path", "") + + config_path = f"{folder}/config.json" if folder else "config.json" + layer_config = get_hf_file_to_dict(config_path, model_config.model, + model_config.revision) + if not layer_config: + return None + + linear = nn.Linear(layer_config.get("in_features", 768), + layer_config.get("out_features", 768), + bias=layer_config.get("bias", True), + 
dtype=torch.float32) + + if _load_dense_weights(linear, folder, model_config): + layers = [linear] + if act_name := layer_config.get("activation_function"): + layers.append(get_act_fn(act_name)) + return nn.Sequential(*layers).to(dtype=torch.float32) + + except Exception: + logger.exception("ST projector loading failed") + + return None + + +def _load_dense_weights(linear: nn.Linear, folder: str, + model_config: "ModelConfig") -> bool: + """Load weights using vLLM's weight_loader pattern.""" + from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader) + + for filename in ["model.safetensors", "pytorch_model.bin"]: + file_path = f"{folder}/{filename}" if folder else filename + + try: + file_bytes = get_hf_file_bytes(file_path, model_config.model, + model_config.revision) + if not file_bytes: + continue + + if filename.endswith(".safetensors"): + from safetensors.torch import load as load_safetensors + state_dict = load_safetensors(file_bytes) + else: + import io + state_dict = torch.load(io.BytesIO(file_bytes), + map_location="cpu", + weights_only=True) + + for weight_key in ["weight", "linear.weight", "dense.weight"]: + if weight_key in state_dict: + weight_loader = getattr(linear.weight, "weight_loader", + default_weight_loader) + weight_loader(linear.weight, + state_dict[weight_key].to(torch.float32)) + + bias_key = weight_key.replace("weight", "bias") + if linear.bias is not None and bias_key in state_dict: + bias_loader = getattr(linear.bias, "weight_loader", + default_weight_loader) + bias_loader(linear.bias, + state_dict[bias_key].to(torch.float32)) + return True + except Exception: + logger.exception("Failed to load %s", filename) + continue + + return False + + def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: model_name = orig_model_name diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index e1368a3f6478a..1c7960fa3e0a5 100644 --- a/vllm/model_executor/models/aria.py 
+++ b/vllm/model_executor/models/aria.py @@ -22,7 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -470,7 +470,7 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index 5cd74bbba4827..687c82ded9d0a 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -18,7 +18,7 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import ( from vllm.config import VllmConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs +from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -242,7 +242,7 @@ class AyaVisionMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token = 
hf_processor.image_token @@ -250,8 +250,7 @@ class AyaVisionMultiModalProcessor( image_processor = hf_processor.image_processor def get_replacement(item_idx: int): - images: ImageProcessorItems = mm_items.get("image", - ImageProcessorItems) + images = mm_items.get_items("image", ImageProcessorItems) image_size: ImageSize = images.get_image_size(item_idx) num_patches = self.info.get_num_patches( image_width=image_size.width, diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 3d328c88ff6e0..32551d8102f32 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -46,7 +46,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsQuant, SupportsV0Only -from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix +from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, + maybe_prefix) logger = logging.get_logger(__name__) @@ -422,10 +423,7 @@ class BartEncoderLayer(nn.Module): if hidden_states.dtype == torch.float16 and ( torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, - min=-clamp_value, - max=clamp_value) + hidden_states = cast_overflow_tensors(hidden_states) return hidden_states @@ -906,3 +904,439 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): }) return loaded_params + + +class MBartEncoderLayer(BartEncoderLayer): + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Args: + hidden_states + torch.Tensor of *encoder* input embeddings. 
+ Returns: + Encoder layer output torch.Tensor + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states) + + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + fc1_out, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(fc1_out) + + hidden_states, _ = self.fc2(hidden_states) + + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() + or torch.isnan(hidden_states).any()): + hidden_states = cast_overflow_tensors(hidden_states) + + return hidden_states + + +class MBartDecoderLayer(BartDecoderLayer): + + def forward( + self, + decoder_hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = decoder_hidden_states + hidden_states = self.self_attn_layer_norm(decoder_hidden_states) + + # Self Attention + hidden_states = self.self_attn(hidden_states=hidden_states) + + hidden_states = residual + hidden_states + + # Cross-Attention Block + + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + hidden_states = self.encoder_attn( + decoder_hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + fc1_out, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(fc1_out) + + hidden_states, _ = self.fc2(hidden_states) + + hidden_states = residual + hidden_states + + return hidden_states + + +class MBartEncoder(nn.Module): + """ + Transformer encoder consisting of *config.encoder_layers* + self attention layers. Each layer is a [`BartEncoderLayer`]. 
+ Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + embed_tokens: Optional[nn.Embedding] = None, + prefix: str = ""): + super().__init__() + + self.cache_config = cache_config + self.quant_config = quant_config + self.lora_config = lora_config + embed_dim = config.d_model + self.max_source_positions = config.max_position_embeddings + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, + embed_dim, + embed_scale=embed_scale) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList([ + MBartEncoderLayer(config, + cache_config, + quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(config.encoder_layers) + ]) + + self.layernorm_embedding = nn.LayerNorm(embed_dim) + self.layer_norm = nn.LayerNorm(config.d_model) # 改动 + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + input_ids + Indices of *encoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + positions + Positions of *encoder* input sequence tokens. 
+ Returns: + Decoder output torch.Tensor + """ + # retrieve input_ids and inputs_embeds + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions(positions) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states=hidden_states) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +class MBartDecoder(nn.Module): + """ + Transformer decoder consisting of *config.decoder_layers* layers. + Each layer is a [`BartDecoderLayer`] + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + embed_tokens: Optional[nn.Embedding] = None, + prefix: str = "", + ): + super().__init__() + self.cache_config = cache_config + self.quant_config = quant_config + self.lora_config = lora_config + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, + config.d_model, + embed_scale=embed_scale) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + + self.layers = nn.ModuleList( + [MBartDecoderLayer(config, cache_config, quant_config, + prefix=f"{prefix}.layers.{layer_idx}") \ + for layer_idx in range(config.decoder_layers)]) + + self.layernorm_embedding = nn.LayerNorm(config.d_model) + self.layer_norm = nn.LayerNorm(config.d_model) + + def forward( + self, + decoder_input_ids: torch.Tensor, + decoder_positions: torch.Tensor, + 
encoder_hidden_states: Optional[torch.Tensor], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + decoder_input_ids + Indices of *decoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + decoder_positions + Positions of *decoder* input sequence tokens. + encoder_hidden_states: + Tensor of encoder output embeddings + Returns: + Decoder output torch.Tensor + """ + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(decoder_input_ids) + else: + decoder_positions = inputs_embeds[:, -1] + + # embed positions + embed_pos = self.embed_positions(decoder_positions) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + + # decoder layers + + for decoder_layer in self.layers: + hidden_states = decoder_layer( + decoder_hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +class MBartModel(nn.Module, SupportsQuant): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" + ] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.encoder = MBartEncoder(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.encoder") + self.decoder = MBartDecoder(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decoder") + + def forward(self, input_ids: torch.Tensor, positions: 
torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor) -> torch.Tensor: + r""" + Args: + input_ids + Indices of *decoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + positions + Positions of *decoder* input sequence tokens. + encoder_input_ids + Indices of *encoder* input sequence tokens in the vocabulary. + encoder_positions: + Positions of *encoder* input sequence tokens. + Returns: + Model output torch.Tensor + """ + + encoder_hidden_states = None + + if encoder_input_ids.numel() > 0: + # Run encoder attention if a non-zero number of encoder tokens + # are provided as input + encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, + positions=encoder_positions) + + # decoder outputs consists of + # (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + decoder_input_ids=input_ids, + decoder_positions=positions, + encoder_hidden_states=encoder_hidden_states) + + return decoder_outputs + + +class MBartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): + base_model_prefix = "model" + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "decoder.": "model.decoder.", + "encoder.": "model.encoder.", + "shared.": "model.shared." 
+ }, + orig_to_new_substr={ + "beta": "bias", + "gamma": "weight", + "LayerNorm": "layernorm", + }, + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + lora_config = vllm_config.lora_config + assert config.tie_word_embeddings + self.config = config + self.model = MBartModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.lm_head = BartParallelLMHead(config.vocab_size, + config.d_model, + embed_scale=embed_scale) + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + return self.model(input_ids, positions, encoder_input_ids, + encoder_positions) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + model_params_dict = dict(self.named_parameters()) + loaded_params = set() + remaining_weights = [] + shared_embedding_weight = None + + for name, loaded_weight in weights: + if any(skip in name + for skip in ["cls.", "pooler.", "final_logits_bias"]): + continue + if any(embed_name in name for embed_name in [ + 'shared.weight', 'encoder.embed_tokens.weight', + 
'decoder.embed_tokens.weight' + ]): + if shared_embedding_weight is None: + shared_embedding_weight = loaded_weight + continue + is_stacked = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + vllm_name = name + for src, dst in self.hf_to_vllm_mapper.orig_to_new_substr.items( + ): + vllm_name = vllm_name.replace(src, dst) + for src, dst in self.hf_to_vllm_mapper.orig_to_new_prefix.items( + ): + if vllm_name.startswith(src): + vllm_name = dst + vllm_name[len(src):] + break + vllm_name = vllm_name.replace(weight_name, param_name) + if vllm_name in model_params_dict: + param = model_params_dict[vllm_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(vllm_name) + is_stacked = True + break + if not is_stacked: + remaining_weights.append((name, loaded_weight)) + loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "pooler."]) + auto_loaded_params = loader.load_weights(remaining_weights, + mapper=self.hf_to_vllm_mapper) + loaded_params.update(auto_loaded_params) + if shared_embedding_weight is not None: + lm_head_param = self.lm_head.weight + weight_loader = getattr(lm_head_param, "weight_loader", + default_weight_loader) + weight_loader(lm_head_param, shared_embedding_weight) + self.model.encoder.embed_tokens.weight = self.lm_head.weight + self.model.decoder.embed_tokens.weight = self.lm_head.weight + loaded_params.update({ + 'model.encoder.embed_tokens.weight', 'lm_head.weight', + 'model.decoder.embed_tokens.weight' + }) + return loaded_params diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 6638f06f98261..b34ca5cbe963d 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -8,7 +8,7 @@ import torch from torch import nn from transformers import BertConfig -from vllm.attention import Attention, AttentionType +from 
vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -28,8 +28,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask -from .interfaces import (SupportsCrossEncoding, SupportsQuant, - default_pooling_type) +from .interfaces import SupportsCrossEncoding, SupportsQuant +from .interfaces_base import default_pooling_type from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix @@ -239,14 +239,13 @@ class BertSelfAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.qkv_proj") - self.attn = Attention(num_heads=self.num_heads, - head_size=self.head_dim, - scale=self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=AttentionType.ENCODER_ONLY) + self.attn = EncoderOnlyAttention(num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -528,9 +527,9 @@ def _encode_token_type_ids(input_ids: torch.Tensor, def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: - ids_mask = torch.ones(input_ids.shape, - dtype=torch.int32, - device=input_ids.device) << TOKEN_TYPE_SHIFT + ids_mask = torch.ones_like(input_ids, + dtype=torch.int32, + device=input_ids.device) << TOKEN_TYPE_SHIFT tokens_mask = ids_mask.bitwise_not() token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index e18b7b7ffabab..dcb7e75456cde 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ 
b/vllm/model_executor/models/bert_with_rope.py @@ -7,7 +7,7 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (divide, get_tensor_model_parallel_rank, @@ -27,13 +27,14 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (SupportsQuant, - default_pooling_type) from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from .interfaces import SupportsQuant +from .interfaces_base import default_pooling_type + class BertWithRopeEmbedding(nn.Module): @@ -119,14 +120,13 @@ class BertWithRopeAttention(nn.Module): self.rotary_emb = get_rope(**rotary_kwargs) - self.attn = Attention(num_heads=self.num_heads, - head_size=self.head_dim, - scale=self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=AttentionType.ENCODER_ONLY) + self.attn = EncoderOnlyAttention(num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") self.out_proj = RowParallelLinear(input_size=hidden_size, output_size=hidden_size, diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 8e3505f872eb2..2f2b880bb0e14 100644 --- a/vllm/model_executor/models/blip2.py +++ 
b/vllm/model_executor/models/blip2.py @@ -15,7 +15,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptIndexTargets, @@ -492,7 +492,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 6e4a399f3cc6e..126404584892f 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -43,7 +43,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only +from .interfaces import SupportsPP, SupportsQuant from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -313,7 +313,7 @@ class BloomModel(nn.Module): return loaded_params -class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant): +class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 8d705f40ce8ff..e6914ad4c495d 100644 --- 
a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -31,7 +31,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -151,7 +151,7 @@ class ChameleonMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() diff --git a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py index f17583768f795..179cc2af8eb3f 100644 --- a/vllm/model_executor/models/cohere2_vision.py +++ b/vllm/model_executor/models/cohere2_vision.py @@ -10,6 +10,8 @@ import torch from torch import nn from transformers import BatchFeature, PretrainedConfig from transformers.models.cohere2_vision import Cohere2VisionConfig +from transformers.models.cohere2_vision.image_processing_cohere2_vision_fast import ( # noqa: E501 + get_optimal_tiled_canvas) from transformers.models.cohere2_vision.processing_cohere2_vision import ( Cohere2VisionProcessor) @@ -21,7 +23,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs +from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems from 
vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -150,14 +152,48 @@ class Cohere2VisionProcessingInfo(BaseProcessingInfo): max_patches = image_processor.max_patches return ImageSize(height=height * max_patches, width=width) - def get_num_patches(self, image_width: int, image_height: int) -> int: + def get_num_patches( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Cohere2VisionProcessor], + ) -> int: """ Calculate the number of image patches for a given image. Uses the HF processor to determine the actual number of patches. """ - return self.get_hf_processor( - ).image_processor.get_number_of_image_patches(image_height, - image_width, {}) + if processor is None: + processor = self.get_hf_processor() + + image_processor = processor.image_processor + + # The current implementation of get_number_of_image_patches + # is incorrect, so we patch it here. + # TODO: Revert once + # https://github.com/huggingface/transformers/pull/40312 is released. 
+ # return image_processor.get_number_of_image_patches(image_height, + # image_width, {}) + + min_patches = image_processor.min_patches + max_patches = image_processor.max_patches + patch_size = image_processor.size + crop_to_patches = image_processor.crop_to_patches + + if not crop_to_patches: + return 1 + + num_columns, num_rows = get_optimal_tiled_canvas( + (image_height, image_width), + (patch_size["height"], patch_size["width"]), + min_patches, + max_patches, + ) + num_patches = num_columns * num_rows + if num_patches > 1: + num_patches += 1 # Thumbnail image + + return num_patches class Cohere2VisionDummyInputsBuilder( @@ -208,6 +244,8 @@ class Cohere2VisionMultiModalProcessor( # Ensure num_patches is available for proper tensor splitting if "num_patches" not in processed_outputs and ( images := mm_data.get("images")) is not None: + hf_processor = self.info.get_hf_processor(**mm_kwargs) + # Fallback calculation if HF processor didn't provide num_patches parsed_images = self._get_data_parser().parse_mm_data({ "image": @@ -217,8 +255,9 @@ class Cohere2VisionMultiModalProcessor( num_patches = [ self.info.get_num_patches( image_width=parsed_images.get_image_size(i).width, - image_height=parsed_images.get_image_size(i).height) - for i in range(len(parsed_images)) + image_height=parsed_images.get_image_size(i).height, + processor=hf_processor, + ) for i in range(len(parsed_images)) ] processed_outputs["num_patches"] = torch.tensor(num_patches) @@ -241,29 +280,29 @@ class Cohere2VisionMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token = hf_processor.image_token + img_tokens_per_tile = int(hf_processor.patch_size**2) img_line_break_token = hf_processor.img_line_break_token boi_token = hf_processor.boi_token eoi_token = 
hf_processor.eoi_token def get_replacement(item_idx: int): - images: ImageProcessorItems = mm_items.get("image", - ImageProcessorItems) + images = mm_items.get_items("image", ImageProcessorItems) image_size: ImageSize = images.get_image_size(item_idx) - num_patches = self.info.get_num_patches(image_size.height, - image_size.width) - img_tokens_per_tile = int(hf_processor.patch_size**2) - single_tile_tokens = image_token * img_tokens_per_tile + \ - img_line_break_token - img_string = f"{boi_token}\ - {single_tile_tokens * num_patches}\ - {eoi_token}" + num_patches = self.info.get_num_patches( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + ) + patch_tokens = (image_token * img_tokens_per_tile + + img_line_break_token) + repl = f"{boi_token}{patch_tokens * num_patches}{eoi_token}" - return PromptUpdateDetails.select_text(img_string, image_token) + return PromptUpdateDetails.select_text(repl, image_token) return [ PromptReplacement( @@ -311,7 +350,7 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, vllm_config=vllm_config, hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), - architectures=["Cohere2ForCausalLM"]) + architectures=config.text_config.architectures) @property def dtype(self): diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 882df7e8162c5..88b3154de2cbb 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -4,6 +4,7 @@ from copy import deepcopy from typing import TYPE_CHECKING import vllm.envs as envs +from vllm.config.compilation import CUDAGraphMode from vllm.logger import init_logger from vllm.model_executor.models import ModelRegistry from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv @@ -275,6 +276,43 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig): "%d for performance.", 1024) +class MambaModelConfig(VerifyAndUpdateConfig): + + @classmethod + def 
verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: + """ + Enable FULL_AND_PIECEWISE cuda graph mode by default (required + to get good performance for mamba layers in V1). + + Args: + vllm_config: vLLM Config + """ + + if not envs.VLLM_USE_V1: + return + + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + compilation_config = vllm_config.compilation_config + + # TODO(tdoublep): remove once prefix caching is enabled + cache_config.enable_prefix_caching = False + logger.info("Hybrid or mamba-based model detected: disabling prefix " + "caching since it is not yet supported.") + + # TODO(tdoublep): remove as full cuda graph support is added + FCG_NOT_SUPPORTED_MODELS = [ + "Lfm2ForCausalLM", "MiniMaxText01ForCausalLM" + ] + + if (model_config.architecture not in FCG_NOT_SUPPORTED_MODELS + and compilation_config.cudagraph_mode is None): + logger.info( + "Hybrid or mamba-based model detected: setting cudagraph mode " + "to FULL_AND_PIECEWISE in order to optimize performance.") + compilation_config.cudagraph_mode = CUDAGraphMode.FULL_AND_PIECEWISE + + class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): @classmethod @@ -293,6 +331,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): if not envs.VLLM_USE_V1: return + # Enable FULL_AND_PIECEWISE by default + MambaModelConfig.verify_and_update_config(vllm_config) + cache_config = vllm_config.cache_config model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config @@ -374,4 +415,6 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "JambaForSequenceClassification": JambaForSequenceClassificationConfig, "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig, "GptOssForCausalLM": GptOssForCausalLMConfig, + "MambaForCausalLM": MambaModelConfig, + "Mamba2ForCausalLM": MambaModelConfig, } diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py new file mode 100644 
index 0000000000000..0c9c83cf61000 --- /dev/null +++ b/vllm/model_executor/models/deepseek_eagle.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer, + DeepseekV3ForCausalLM) +from vllm.model_executor.sampling_metadata import SamplingMetadata + +from .utils import AutoWeightsLoader, maybe_prefix + + +@support_torch_compile +class DeepseekV2Model(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + start_layer_id: int = 0, + ) -> None: + super().__init__() + self.config = vllm_config. 
\ + speculative_config.draft_model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.vocab_size = self.config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "embed_tokens"), + ) + + self.layers = nn.ModuleList([ + DeepseekV2DecoderLayer( + self.config, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + ) for i in range(self.config.num_hidden_layers) + ]) + + self.fc = nn.Linear( + self.config.model.hidden_size * 2, + self.config.model.hidden_size, + bias=False, + ) + + self.enorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.hnorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.norm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + input_embeds = self.embed_tokens(input_ids) + + inputs = torch.cat( + [self.enorm(input_embeds), + self.hnorm(hidden_states)], dim=-1) + hidden_states = self.fc(inputs) + residual = None + for layer in self.layers: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states, hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("fused_qkv_a_proj", "q_a_proj", 0), + ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, 
expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name_mapped = name.replace(weight_name, param_name) + + # QKV fusion is optional, fall back to normal + # weight loading if it's not enabled + # if go with fusion option, then update name + if ((param_name == "fused_qkv_a_proj") + and name_mapped not in params_dict): + continue + else: + name = name_mapped + + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # if PP disabled then draft will share embed with target + if get_pp_group().world_size == 1 and \ + "embed_tokens." in name: + continue + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + self.config = vllm_config. 
\ + speculative_config.draft_model_config.hf_config + quant_config = vllm_config.quant_config + target_layer_num = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config) + self.model = DeepseekV2Model(vllm_config=vllm_config, + prefix="model", + start_layer_id=target_layer_num) + + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config) + + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.config.vocab_size, + scale=logit_scale) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if inputs_embeds is not None: + raise NotImplementedError( + f"{type(self).__name__} does not support multimodal inputs yet." + ) + return self.model(input_ids, positions, hidden_states) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_prefixes=None, + ) + + model_weights = {} + for name, loaded_weight in weights: + if "lm_head" not in name: + name = "model." 
+ name + model_weights[name] = loaded_weight + loader.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 2e026d582a6de..0ad001be71c19 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -158,14 +158,13 @@ class DeepSeekMTP(nn.Module, SupportsPP): self, input_ids: torch.Tensor, positions: torch.Tensor, - previous_hidden_states: torch.Tensor, + hidden_states: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, - previous_hidden_states, inputs_embeds, - spec_step_idx) + hidden_states = self.model(input_ids, positions, hidden_states, + inputs_embeds, spec_step_idx) return hidden_states def compute_logits( @@ -213,13 +212,15 @@ class DeepSeekMTP(nn.Module, SupportsPP): # for mlp.experts[0].gate_gate_up_proj, which breaks load. if (("mlp.experts." in name) and name not in params_dict): continue - name = name.replace(weight_name, param_name) + name_mapped = name.replace(weight_name, param_name) # QKV fusion is optional, fall back to normal # weight loading if it's not enabled if ((param_name == "fused_qkv_a_proj") - and name not in params_dict): + and name_mapped not in params_dict): continue + else: + name = name_mapped # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index f199da135ec76..7657e7cb003d6 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -126,16 +126,16 @@ class DeepseekV2MoE(nn.Module): prefix=f"{prefix}.gate") if config.topk_method == "noaux_tc": self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts)) + torch.empty(config.n_routed_experts, dtype=torch.float32)) else: self.gate.e_score_correction_bias = None # Load balancing settings. vllm_config = get_current_vllm_config() - parallel_config = vllm_config.parallel_config + eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb - self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts self.n_physical_experts = (self.n_logical_experts + self.n_redundant_experts) diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index e0acca75d9dd6..ceb5e1364b68d 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -21,11 +21,12 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.transformers import replace_linear_class from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, MultiModalHashes, + BaseProcessingInfo, + MultiModalProcessingInfo, PromptReplacement, PromptUpdate) from 
vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors @@ -252,7 +253,7 @@ class DeepseekVL2MultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -289,9 +290,7 @@ class DeepseekVL2MultiModalProcessor( mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 2 vs > 2 # Since the processing cache assumes that the processor output is # invariant of how many images are passed per prompt, we only @@ -302,7 +301,6 @@ class DeepseekVL2MultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, ) return super()._cached_apply_hf_processor( @@ -310,7 +308,6 @@ class DeepseekVL2MultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, ) diff --git a/vllm/model_executor/models/donut.py b/vllm/model_executor/models/donut.py new file mode 100644 index 0000000000000..c00db52371b68 --- /dev/null +++ b/vllm/model_executor/models/donut.py @@ -0,0 +1,387 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Literal, Optional, Union + +import torch +import torch.nn as nn +from transformers import BatchFeature, NougatProcessor + +from vllm.config 
import VllmConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.bart import BartParallelLMHead, MBartDecoder +from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, + SupportsMultiModal, + SupportsV0Only) +from vllm.model_executor.models.swin import SwinModel +from vllm.model_executor.models.utils import (AutoWeightsLoader, + _flatten_embeddings, flatten_bn) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptIndexTargets, PromptInsertion, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.utils.tensor_schema import TensorSchema, TensorShape + + +class MBartDecoderWrapper(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.decoder = MBartDecoder(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decoder") + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + + self.config = config + self.model = MBartDecoderWrapper(vllm_config=vllm_config, + prefix=f"{prefix}.model") + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.vocab_size = config.vocab_size + self.lm_head = 
BartParallelLMHead(self.vocab_size, + config.d_model, + embed_scale=embed_scale) + + self.logits_processor = LogitsProcessor(self.vocab_size, + config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + Returns: + Output torch.Tensor + """ + + return self.model(decoder_input_ids=input_ids, + decoder_positions=positions, + encoder_hidden_states=inputs_embeds) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "final_logits_bias" in name: + continue + # if self.config.tie_word_embeddings and "embed_tokens" in name: + # continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class DonutImagePixelInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - c: Number of channels (3) + - h: Height + - w: Width + """ + type: Literal["pixel_values"] + data: 
Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")] + + +class DonutProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self): + return self.ctx.get_hf_processor() + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_num_image_tokens(self) -> int: + return 1 + + +class DonutDummyInputsBuilder(BaseDummyInputsBuilder[DonutProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_hf_config( + ).encoder.image_size + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + +class DonutMultiModalProcessor(EncDecMultiModalProcessor[DonutProcessingInfo]): + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return prompt + + def create_decoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return prompt + + @property + def pad_dummy_encoder_prompt(self) -> bool: + return True + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + hf_processor = self.info.get_hf_processor() + if mm_data: + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs, tok_kwargs) + if isinstance(hf_processor, NougatProcessor): + processed_outputs["input_ids"] = 
processed_outputs["labels"] + else: + tokenizer = hf_processor.tokenizer + processed_outputs = tokenizer(prompt, + add_special_tokens=False, + return_tensors="pt") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor() + tokenizer = hf_processor.tokenizer + pad_token_id = tokenizer.pad_token_id + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [pad_token_id] * num_image_tokens + + return [ + PromptInsertion( + modality="image", + target=PromptIndexTargets.start(), + insertion=image_tokens, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(DonutMultiModalProcessor, + info=DonutProcessingInfo, + dummy_inputs=DonutDummyInputsBuilder) +class DonutForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + processor_config = vllm_config.model_config.hf_image_processor_config + + self.config = config + self.vision_config = config.encoder + self.processor_config = processor_config + self.encoder = SwinModel(config=config.encoder) + + self.decoder = DonutLanguageForConditionalGeneration( + vllm_config=vllm_config.with_hf_config(config.decoder), + prefix=f"{prefix}.decoder", + ) + self.pad_token_id = config.pad_token_id + + def _parse_and_validate_image_input(self, **kwargs: object): + pixel_values: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "pixel_values", None) + image_embeds: Optional[Union[list[list[torch.Tensor]], + 
list[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None and image_embeds is not None: + raise ValueError( + "Both pixel values and image embeds are provided.") + + if pixel_values is not None: + h, w = self.config.encoder.image_size + return DonutImagePixelInputs(type="pixel_values", + data=flatten_bn(pixel_values, + concat=True), + resolve_bindings={ + "h": h, + "w": w, + }) + + if image_embeds is not None: + raise NotImplementedError + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, image_input: DonutImagePixelInputs) -> torch.Tensor: + assert image_input["type"] == "pixel_values" + pixel_values = image_input["data"] + dtype = next(self.encoder.parameters()).dtype + pixel_values = pixel_values.to(dtype) + return self.encoder(pixel_values) + + def get_language_model(self) -> torch.nn.Module: + return self.decoder + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings, + ) -> torch.Tensor: + return _flatten_embeddings(multimodal_embeddings) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + encoder_input_ids + torch.Tensor of *encoder* input token ids. 
+ encoder_positions + torch.Tensor of *encoder* position indices + Returns: + Output torch.Tensor + """ + + inputs_embeds = None + if encoder_input_ids.numel() > 0: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(encoder_input_ids, + vision_embeddings) + + hidden_states = self.decoder(input_ids, + positions, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.decoder.compute_logits(hidden_states, sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py new file mode 100644 index 0000000000000..d880fc434e20f --- /dev/null +++ b/vllm/model_executor/models/ernie45_vl.py @@ -0,0 +1,1504 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Baidu team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Erine VL model compatible with HuggingFace weights.""" +import math +from collections.abc import Iterable, Mapping, Sequence +from functools import partial +from typing import Any, Callable, Literal, Optional, TypedDict, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from transformers import BatchFeature + +from vllm.config import VllmConfig +from vllm.distributed import parallel_state +from vllm.distributed import utils as dist_utils +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.activation import QuickGELU +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.platforms import _Backend, current_platform +from vllm.sequence import IntermediateTensors + +from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM 
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate the last dimension of ``x`` for rotary embeddings.

    With ``interleaved=False`` the last dimension is split into two halves
    ``(a, b)`` and ``(-b, a)`` is returned. With ``interleaved=True`` the
    even-indexed and odd-indexed elements play the roles of ``a`` and ``b``
    and each ``(-b_i, a_i)`` pair stays adjacent in the output.
    """
    if interleaved:
        evens = x[..., ::2]
        odds = x[..., 1::2]
        paired = torch.stack((-odds, evens), dim=-1)
        return rearrange(paired, "... d two -> ... (d two)", two=2)

    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)


def apply_rotary_emb_torch(x: torch.Tensor,
                           cos: torch.Tensor,
                           sin: torch.Tensor,
                           interleaved: bool = False) -> torch.Tensor:
    """
    Apply rotary embedding to the leading ``rotary_dim`` features of ``x``.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Features beyond ``rotary_dim`` are passed through unchanged.
    """
    rotary_dim = 2 * cos.shape[-1]
    assert rotary_dim <= x.shape[-1]

    # Expand cos/sin to the full rotary width and add a broadcast axis for
    # the heads dimension.
    if interleaved:
        pattern = "... d -> ... 1 (d 2)"
    else:
        pattern = "... d -> ... 1 (2 d)"
    cos_full = repeat(cos, pattern)
    sin_full = repeat(sin, pattern)

    rotary_part = x[..., :rotary_dim]
    passthrough_part = x[..., rotary_dim:]
    rotated = (rotary_part * cos_full +
               rotate_half(rotary_part, interleaved) * sin_full)
    return torch.cat([rotated, passthrough_part], dim=-1)


def apply_rotary_pos_emb_vision(t: torch.Tensor,
                                freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding built from ``freqs`` to ``t``.

    The computation runs on a float32 copy of ``t`` and the result is cast
    back to the type of ``t``. On CUDA platforms the flash-attn rotary
    implementation is used; otherwise the pure-torch fallback in this module.
    """
    if current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import (apply_rotary_emb as
                                                        rotary_fn)
    else:
        rotary_fn = apply_rotary_emb_torch

    rotated = rotary_fn(t.float(), freqs.cos(), freqs.sin())
    return rotated.type_as(t)


def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
    """All-gather the input tensor interleavely across model parallel group.

    Every rank's tensor is split along the last dimension into chunks of
    ``hidden_size // tp_size``; the chunks are then concatenated so that the
    i-th chunk of every rank comes before the (i+1)-th chunk of any rank.
    """
    import torch.distributed as dist

    per_rank = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
    dist.all_gather(per_rank,
                    local_tensor,
                    group=parallel_state.get_tp_group().device_group)

    chunk_width = hidden_size // tp_size
    per_rank_chunks = [
        torch.split(gathered, chunk_width, -1) for gathered in per_rank
    ]

    # zip(*...) groups the chunks that share the same chunk index across
    # ranks, which produces the interleaved ordering.
    interleaved = []
    for same_index_chunks in zip(*per_rank_chunks):
        interleaved.extend(same_index_chunks)
    return torch.cat(interleaved, dim=-1)
+ self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, self.tp_size) + + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv") + self.proj = RowParallelLinear(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj") + + # Detect attention implementation. + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + if self.attn_backend not in { + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, + _Backend.ROCM_AITER_FA + }: + raise RuntimeError( + f"Ernie45-VL does not support {self.attn_backend} backend now." 
+ ) + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + } + + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: + # [s, b, 3 * head * head_dim] + seq_len, bs, _ = qkv.shape + if self.tp_size > 1: + qkv = all_gather_interleave(qkv, self.qkv.hidden_size, + self.tp_size) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] + q, k, v = qkv.chunk(3, dim=2) + + # 3 * [s, b, head * head_dim] + if self.tp_size > 1: + splitter = partial(dist_utils.split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + + # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] + new_shape = (seq_len, bs, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + q, k, v = (x.view(*new_shape) for x in (q, k, v)) + return q, k, v + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + # [s, b, c] --> [s, b, head * 3 * head_dim] + x, _ = self.qkv(x) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] + q, k, v = self.split_qkv(x) + batch_size = q.shape[1] + + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() + for x in (q, k, v)) + if rotary_pos_emb is not None: + q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) + k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + + if self.is_flash_attn_backend: + # from vllm_flash_attn.flash_attn_interface import ( + # flash_attn_varlen_func) + if self.attn_backend == _Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + else: + from flash_attn import flash_attn_varlen_func + + q, k, v = (rearrange(x, "b s ... 
class Ernie4_5_VisionMLP(nn.Module):
    """Feed-forward sub-layer of a vision block: fc1 -> activation -> fc2."""

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # Up-projection to the hidden width.
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.act = act_layer()
        # Down-projection back to the input width.
        self.fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both linear layers return a tuple; only the first element (the
        # projected tensor) is used here.
        hidden, _ = self.fc1(x)
        projected, _ = self.fc2(self.act(hidden))
        return projected


class Ernie4_5_VisionBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Fall back to LayerNorm(eps=1e-6) when no norm factory is supplied.
        make_norm = (norm_layer if norm_layer is not None else partial(
            nn.LayerNorm, eps=1e-6))
        self.norm1 = make_norm(dim)
        self.norm2 = make_norm(dim)

        self.attn = Ernie4_5_VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        self.mlp = Ernie4_5_VisionMLP(
            dim,
            int(dim * mlp_ratio),
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
            self,
            hidden_states: torch.Tensor,
            cu_seqlens: torch.Tensor,
            rotary_pos_emb: torch.Tensor,
            max_seqlen: Optional[int] = None,  # Only used for Flash Attention
            seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        # Attention sub-layer with residual connection.
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        hidden_states = hidden_states + attn_out

        # MLP sub-layer with residual connection.
        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out
class Ernie4_5_VisionRotaryEmbedding(nn.Module):
    """Produces rotary frequency tables ``position * inv_freq``."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        # One inverse frequency per pair of features: theta ** -(2i / dim).
        # Kept as a plain attribute (not a parameter or registered buffer).
        exponents = torch.arange(
            start=0, end=dim, step=2, dtype=torch.float32) / dim
        self.inv_freq = 1.0 / theta**exponents

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return a ``(seqlen, len(inv_freq))`` table of position * inv_freq."""
        positions = torch.arange(seqlen,
                                 device=self.inv_freq.device,
                                 dtype=self.inv_freq.dtype)
        return torch.outer(input=positions, vec2=self.inv_freq)
def compute_attn_mask_seqlen(
    self, cu_seqlens: torch.Tensor
) -> tuple[Optional[int], Optional[list[int]]]:
    """Pre-compute the per-backend sequence-length metadata for attention.

    Args:
        cu_seqlens: cumulative sequence boundaries; consecutive differences
            are the individual sequence lengths.

    Returns:
        ``(max_seqlen, seqlens)``. ``max_seqlen`` is set only for the
        flash-attention style backends (``FLASH_ATTN`` and
        ``ROCM_AITER_FA``); ``seqlens`` is set only for ``XFORMERS``.
        Whatever the active backend does not use is ``None``.
    """
    max_seqlen, seqlens = None, None
    # ROCM_AITER_FA is handled by the attention layer through the same
    # flash_attn_varlen_func call as FLASH_ATTN, so it needs max_seqlen too.
    if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    elif self.attn_backend == _Backend.XFORMERS:
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    return max_seqlen, seqlens
# === Vision Inputs === #


class Ernie4_5_VLImagePixelInputs(TypedDict):
    """Pre-processed image patches together with their grid description."""

    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    """Shape:
    `(num_patches, num_channels * patch_size * patch_size)`
    """

    grid_thw: torch.Tensor
    """Shape: `(num_images, 3)`
    This should be in `(grid_t, grid_h, grid_w)` format.
    """


Ernie4_5_VLImageInputs = Ernie4_5_VLImagePixelInputs


class Ernie4_5_VLVideoPixelInputs(TypedDict):
    """Pre-processed video patches together with their grid description."""

    type: Literal["pixel_values_videos"]
    pixel_values_videos: torch.Tensor
    """Shape:
    `(num_patches,
      num_channels * temporal_patch_size * patch_size * patch_size)`
    """

    video_grid_thw: torch.Tensor
    """Shape: `(num_videos, 3)`

    This should be in `(grid_t, grid_h, grid_w)` format.
    """


# The video alias must point at the video TypedDict (it previously pointed at
# the image one, so type checkers saw the wrong keys for video inputs).
Ernie4_5_VLVideoInputs = Ernie4_5_VLVideoPixelInputs

# === Vision Processor === #


def round_by_factor(number: Union[int, float], factor: int) -> int:
    """Return the multiple of ``factor`` nearest to ``number``."""
    return round(number / factor) * factor


def ceil_by_factor(number: Union[int, float], factor: int) -> int:
    """Return the smallest multiple of ``factor`` that is >= ``number``."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: Union[int, float], factor: int) -> int:
    """Return the largest multiple of ``factor`` that is <= ``number``."""
    return math.floor(number / factor) * factor


def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 4 * 28 * 28,
    max_pixels: int = 16384 * 28 * 28,
) -> tuple[int, int]:
    """Pick a target ``(height, width)`` that fits the pixel budget.

    Both returned sides are multiples of ``factor`` and their product lies in
    ``[min_pixels, max_pixels]``. Inputs whose aspect ratio exceeds 200:1 are
    first clamped to 200:1 (keeping the short side, rounded to ``factor``).

    Raises:
        ValueError: if no size satisfying the pixel budget was found.
    """
    MAX_RATIO = 200
    if max(height, width) / min(height, width) > MAX_RATIO:
        # Clamp extreme aspect ratios: keep the short side and cap the long
        # side at MAX_RATIO times the short side.
        if height > width:
            new_width = max(factor, round_by_factor(width, factor))
            new_height = floor_by_factor(new_width * MAX_RATIO, factor)
        else:
            new_height = max(factor, round_by_factor(height, factor))
            new_width = floor_by_factor(new_height * MAX_RATIO, factor)

        height = new_height
        width = new_width

    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        # Too large: shrink both sides by the same ratio and round down.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        # Too small: grow both sides by the same ratio and round up.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)

    if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels:
        raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}")

    return h_bar, w_bar
def spatial_conv_reshape(self, x, spatial_conv_size):
    """Merge every ``spatial_conv_size**2`` consecutive rows of ``x`` into one.

    ``x`` must be 2-D ``(rows, channels)``; the result has
    ``channels * spatial_conv_size**2`` columns.
    """
    # Unpacking into exactly two names also enforces that x is 2-D.
    _, channels = x.shape
    merged_width = channels * (spatial_conv_size**2)
    return x.reshape([-1, merged_width])
self.spatial_linear1(x) + x = self.spatial_gelu(x) + x, _ = self.spatial_linear2(x) + x = self.spatial_norm(x) + + return x + + def fwd_placeholder(x, grid_thw, to_tensor=False): + + grid_thw_cpu = grid_thw.cpu().numpy() + grid_t, grid_hw = grid_thw_cpu[:, 0], grid_thw_cpu[:, 1:] + grid_hw_after_conv = grid_hw.prod(-1) // (self.spatial_conv_size** + 2) + + tokens_per_img_or_vid = grid_thw_cpu.prod(-1) // ( + self.spatial_conv_size**2) + batch_offset = np.empty(tokens_per_img_or_vid.size, + dtype=tokens_per_img_or_vid.dtype) + batch_offset[0] = 0 + batch_offset[1:] = tokens_per_img_or_vid.cumsum()[:-1] + + slice_offsets = [] + for temporoal_size, spatial_size, b_offset in zip( + grid_t, grid_hw_after_conv, batch_offset): + for temp_offset in range(0, temporoal_size, 2): + slice_offsets.append( + np.arange( + b_offset + (temp_offset) * spatial_size, + b_offset + (temp_offset + 1) * spatial_size, + )) + slice_offsets = torch.tensor(np.concatenate(slice_offsets, + axis=-1)).to(x.device) + + slice_offsets2 = [] + for temporoal_size, spatial_size, b_offset in zip( + grid_t, grid_hw_after_conv, batch_offset): + for temp_offset in range(1 if temporoal_size > 1 else 0, + temporoal_size, 2): + slice_offsets2.append( + np.arange( + b_offset + (temp_offset) * spatial_size, + b_offset + (temp_offset + 1) * spatial_size, + )) + slice_offsets2 = torch.tensor( + np.concatenate(slice_offsets2, axis=-1)).to(x.device) + + x_timestep_1 = torch.index_select(x, dim=0, index=slice_offsets) + x_timestep_2 = torch.index_select(x, dim=0, index=slice_offsets2) + x = torch.concat([x_timestep_1, x_timestep_2], dim=-1) + return x + + def fwd_temporal(x): + x, _ = self.temporal_linear1(x) + x = self.temporal_gelu(x) + x, _ = self.temporal_linear2(x) + x = self.temporal_norm(x) + return x + + def fwd_mlp(x): + x, _ = self.mlp(x) + x = self.after_norm(x) + return x + + x = fwd_spatial(x) + if self.use_temporal_conv: + x = fwd_placeholder(x, grid_thw) + x = fwd_temporal(x) + x = fwd_mlp(x) + return 
x + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.model_config.hf_config + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(use_fast=True, **kwargs) + + def get_image_processor(self, **kwargs: object): + return self.get_hf_processor(**kwargs).image_processor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": None} + + def _get_vision_info( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 1, + do_resize: bool = True, + image_processor: Optional[Any], + ) -> tuple[ImageSize, int]: + if image_processor is None: + image_processor = self.get_image_processor() + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + + patch_size = vision_config.patch_size + spatial_conv_size = hf_config.spatial_conv_size + temporal_conv_size = hf_config.temporal_conv_size + + if do_resize: + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * spatial_conv_size, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + ) + preprocessed_size = ImageSize(width=resized_width, + height=resized_height) + else: + preprocessed_size = ImageSize(width=image_width, + height=image_height) + + grid_t = max(num_frames // temporal_conv_size, 1) + grid_h = preprocessed_size.height // patch_size + grid_w = preprocessed_size.width // patch_size + + num_patches = grid_t * grid_h * 
grid_w + num_vision_tokens = num_patches // (spatial_conv_size**2) + + return preprocessed_size, num_vision_tokens + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + image_processor: Optional[Any], + ) -> int: + _, num_image_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + image_processor=image_processor, + ) + return num_image_tokens + + def get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + image_processor: Optional[Any], + ) -> int: + _, num_video_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + num_frames=num_frames, + image_processor=image_processor, + ) + return num_video_tokens + + def get_image_size_with_most_features(self) -> ImageSize: + max_image_size, _ = self._get_vision_info( + image_width=9999999, + image_height=9999999, + image_processor=None, + ) + return max_image_size + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + num_image_tokens = self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + image_processor=None, + ) + return num_image_tokens + + def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + num_frames = 0 + + while True: + next_num_frames = num_frames + 1 + next_max_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + image_processor=None, + ) + + if next_max_tokens > max_tokens: + break + + num_frames = next_num_frames + + # If the number of frames is odd, discard one frame. 
+ if num_frames % 2 != 0: + num_frames -= 1 + + return num_frames + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = self._get_max_video_frames(seq_len - + max_image_tokens) + max_frames_per_video = min(max_total_frames // max(max_videos, 1), + _MAX_FRAMES_PER_VIDEO) + + return max(max_frames_per_video, 2) + + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_num_frames_with_most_features( + seq_len, mm_counts), + image_processor=None, + ) + + +class Ernie4_5VLMultiModalProcessor( + BaseMultiModalProcessor[Ernie4_5_VLProcessingInfo]): + + def _pixel_values_norm( + self, + pixel_values: torch.Tensor, + mm_kwargs: object, + ) -> torch.Tensor: + hf_config = self.info.get_hf_config() + vision_config = hf_config.vision_config + image_processor = self.info.get_image_processor(**mm_kwargs) + image_mean_tensor = torch.tensor(image_processor.image_mean, + dtype=torch.float32).reshape( + [1, 3, 1, 1]) + image_std_tensor = torch.tensor(image_processor.image_std, + dtype=torch.float32).reshape( + [1, 3, 1, 1]) + rescale_factor = torch.tensor(image_processor.rescale_factor, + dtype=torch.float32) + patch_size_squared = vision_config.patch_size**2 + + image_mean_tensor = (image_mean_tensor.squeeze( + [-2, -1]).repeat_interleave(patch_size_squared, -1)) + image_std_tensor = (image_std_tensor.squeeze( + [-2, -1]).repeat_interleave(patch_size_squared, -1)) + + if not image_mean_tensor.is_contiguous(): + image_mean_tensor = image_mean_tensor.contiguous() + if not image_std_tensor.is_contiguous(): + image_std_tensor = 
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Run the HF processor and split its output into image/video fields.

    Text-only prompts are tokenized directly. Otherwise the HF processor is
    called, its ``images`` tensor is normalized, ``None`` entries are
    dropped, and the combined ``grid_thw`` / ``images`` pair is split into
    ``image_grid_thw`` / ``pixel_values`` and ``video_grid_thw`` /
    ``pixel_values_videos``. ``mm_data`` is never modified.
    """
    # when the prompt is not empty but the multimodal data is empty,
    # directly invoke the tokenizer.
    if "images" not in mm_data and "videos" not in mm_data and prompt != "":
        tokenizer = self.info.get_tokenizer()
        prompt_ids = tokenizer.encode(prompt)
        return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

    # Read with defaults instead of inserting keys: mm_data is a Mapping
    # owned by the caller and must not be mutated here.
    images = mm_data.get("images", [])
    videos = mm_data.get("videos", [])
    processor_output = self.info.ctx.call_hf_processor(
        self.info.get_hf_processor(**mm_kwargs),
        dict(text=[prompt], images=images, videos=videos),
        dict(**mm_kwargs, **tok_kwargs),
    )
    if processor_output is None:
        return processor_output

    pixel_values = processor_output['images']
    if pixel_values is not None:
        processor_output['images'] = self._pixel_values_norm(
            pixel_values, mm_kwargs)

    # Drop empty entries first so the split below does not depend on the
    # order in which the processor emitted its keys.
    for key in list(processor_output.keys()):
        if processor_output[key] is None:
            del processor_output[key]

    # Divide the processor_output into two modalities: image and video.
    if "grid_thw" in processor_output:
        grid_thw = processor_output['grid_thw']
        pixel_values_all = processor_output['images']
        # Identify elements where the first
        # dimension is greater than 1 and
        # treat them as the video modality
        is_video = grid_thw[:, 0] > 1
        processor_output["video_grid_thw"] = grid_thw[is_video]
        processor_output["image_grid_thw"] = grid_thw[~is_video]
        # The first image_patch_num rows are taken as image patches and the
        # remainder as video patches.
        image_patch_num = processor_output["image_grid_thw"].prod(
            dim=1).sum()
        processor_output['pixel_values'] = pixel_values_all[:image_patch_num]
        processor_output['pixel_values_videos'] = pixel_values_all[
            image_patch_num:]
        del processor_output['images']

    return processor_output
modality=modality, + target=before_placeholder[modality], + replacement=partial(get_replacement_ernie45vl, + modality=modality), + ) for modality in ("image", "video") + ] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + + +class Ernie4_5_VLDummyInputsBuilder( + BaseDummyInputsBuilder[Ernie4_5_VLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + prompt = "" + for i in range(num_images): + prompt += (f"Picture {i+1}:" + "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>") + + for i in range(num_videos): + prompt += (f"Video {i+1}:" + "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>") + return prompt + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos(width=target_width, + height=target_height, + num_frames=target_num_frames, + 
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
    """Return the prompt snippet that stands for one multimodal item.

    Args:
        modality: Modality name; anything starting with "image" or
            "video" is accepted.
        i: Item index (not used by this model).

    Raises:
        ValueError: If the modality is neither image nor video.
    """
    snippets = (
        ("image", "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>"),
        ("video", "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>"),
    )
    for prefix, snippet in snippets:
        if modality.startswith(prefix):
            return snippet

    raise ValueError("Only image or video modality is supported")
"vision_model"), + ) + + self.language_model = Ernie4_5_VLMoeForCausalLM( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.resampler_model = VariableResolutionResamplerModel( + self.config.pixel_hidden_size, + self.config.hidden_size, + self.config.spatial_conv_size, + self.config.temporal_conv_size, + config=self.config, + prefix=maybe_prefix(prefix, "resampler_model")) + + self.visual_token_mask = None + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + """compute logits""" + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def _vision_forward( + self, + pixel_values: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + if grid_thw is not None: + grid_thw = grid_thw[grid_thw > 0] + if grid_thw.numel() % 3 != 0: + raise ValueError( + f"grid_thw has {grid_thw.numel()} elements after filtering," + "which is not divisible by 3.") + grid_thw = grid_thw.reshape(-1, 3) + # example: [[1,64,64],[2,80,80]] -> [[1,64,64],[1,80,80],[1,80,80]] + grid_thw = F.pad( + torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0), + [1, 0, 0, 0], + value=1, + ) + image_features = self.vision_model(pixel_values, grid_thw) + return image_features + + def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: + if getattr(self.config, "im_patch_id", None) is not None: + self.visual_token_mask = ( + input_ids == self.config.im_patch_id).reshape(-1, 1) + else: + self.visual_token_mask = None + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. 
def _parse_and_validate_image_input(
        self, **kwargs: object) -> Optional[Ernie4_5_VLImageInputs]:
    """Extract and normalise the image tensors from the model kwargs.

    Reads ``pixel_values`` and ``image_grid_thw`` from ``kwargs``.

    Returns:
        ``None`` when no ``pixel_values`` were supplied, otherwise an
        ``Ernie4_5_VLImagePixelInputs`` holding both tensors after they
        have gone through ``_validate_and_reshape_mm_tensor``.

    Raises:
        ValueError: Propagated from ``_validate_and_reshape_mm_tensor``
            when either input has an unsupported type or rank.
    """
    pixel_values = kwargs.pop("pixel_values", None)
    image_grid_thw = kwargs.pop("image_grid_thw", None)

    if pixel_values is None:
        return None

    # The helper either returns a tensor or raises, so no further type
    # check is needed afterwards (mirrors the video counterpart).
    pixel_values = self._validate_and_reshape_mm_tensor(
        pixel_values, "image pixel values")
    image_grid_thw = self._validate_and_reshape_mm_tensor(
        image_grid_thw, "image grid_thw")

    return Ernie4_5_VLImagePixelInputs(type="pixel_values",
                                       pixel_values=pixel_values,
                                       image_grid_thw=image_grid_thw)
self.vision_model.dtype) + image_features = self._vision_forward(pixel_values=pixel_values, + grid_thw=grid_thw) + image_embeds = self.resampler_model(image_features, grid_thw) + + merge_size = self.vision_model.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, + video_input: Ernie4_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values_videos = video_input["pixel_values_videos"].type( + self.vision_model.dtype) + video_features = self._vision_forward(pixel_values=pixel_values_videos, + grid_thw=grid_thw) + video_embeds = self.resampler_model(video_features, grid_thw) + + merge_size = self.vision_model.spatial_merge_size + sizes = (grid_thw.prod(-1) // + self.config.temporal_conv_size) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] 
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs,
):
    """Run the language backbone, forwarding the pending visual token mask.

    All arguments are passed through to ``self.language_model.model``.
    If ``self.visual_token_mask`` is set, it is right-padded with zeros
    to the length of ``inputs_embeds`` when the two differ, passed along
    as ``visual_token_mask``, and then reset to ``None`` so it is used
    for exactly one forward call.

    Returns:
        The value returned by ``self.language_model.model``.
    """

    forward_kwargs = {
        "input_ids": input_ids,
        "positions": positions,
        "intermediate_tensors": intermediate_tensors,
        "inputs_embeds": inputs_embeds,
    }

    if self.visual_token_mask is not None:

        # NOTE(review): this dereferences inputs_embeds, so it assumes
        # embeddings are always supplied whenever a mask is pending --
        # confirm against the caller.
        if self.visual_token_mask.shape[0] != inputs_embeds.shape[0]:
            # Presumably inputs_embeds was padded by the runner and is
            # the longer of the two -- TODO confirm; a shorter
            # inputs_embeds would make padding_len negative.
            padding_len = inputs_embeds.shape[
                0] - self.visual_token_mask.shape[0]
            # right pad False
            pad = torch.zeros(
                (padding_len, self.visual_token_mask.shape[1]),
                dtype=self.visual_token_mask.dtype,
                device=self.visual_token_mask.device)
            self.visual_token_mask = torch.cat(
                [self.visual_token_mask, pad], dim=0)

        forward_kwargs.update(
            {"visual_token_mask": self.visual_token_mask})
        # One-shot: clear the mask so it cannot leak into a later call.
        self.visual_token_mask = None

    hidden_states = self.language_model.model(
        **forward_kwargs,
        **kwargs,
    )

    return hidden_states
+"""Inference-only Ernie VL model compatible with HuggingFace weights."""
qkv_bias: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) if len(prefix) > 0 else 0 + self.layer_idx = layer_idx + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear(hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + t_rope = freq_allocation + h_rope = (self.head_dim // 2 - freq_allocation) // 2 + w_rope = (self.head_dim // 2 - freq_allocation) // 2 + + self.rotary_emb = Ernie4_5_VLRotaryEmbedding( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position_embeddings=max_position_embeddings, + base=rope_theta, + is_neox_style=False, + dtype=torch.get_default_dtype(), 
class Ernie4_5_VLMoeMoE(nn.Module):
    """Feed-forward block with separate text and vision experts.

    ``config.moe_num_experts``, ``config.moe_layer_start_index``,
    ``config.moe_intermediate_size`` and the optional
    ``config.moe_layer_end_index`` are indexed with ``[0]`` for the text
    branch and ``[1]`` for the vision branch. For each branch, a routed
    ``FusedMoE`` plus a ``ReplicatedLinear`` gate is created when the
    layer index lies inside that branch's ``[start, end]`` range;
    otherwise a dense ``Ernie4_5_VLMoeMLP`` is created instead.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the experts for the layer identified by ``prefix``.

        Raises:
            ValueError: If the tensor-parallel size exceeds the largest
                entry of ``config.moe_num_experts``.
        """
        super().__init__()

        layer_idx = extract_layer_index(prefix)
        self.layer_idx = layer_idx
        self.tp_size = get_tensor_model_parallel_world_size()
        self.has_shared_experts = (getattr(config, "moe_num_shared_experts", 0)
                                   > 0)
        self.hidden_size = config.hidden_size

        moe_num_experts = config.moe_num_experts
        max_moe_num_experts = max(moe_num_experts)

        if self.tp_size > max_moe_num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {moe_num_experts}.")

        moe_layer_start_index = config.moe_layer_start_index
        text_moe_layer_start_index = moe_layer_start_index[0]
        vision_moe_layer_start_index = moe_layer_start_index[1]
        # Optional in the config: default to "MoE up to the last layer"
        # for both branches, as the decoder layer does.
        moe_layer_end_index = getattr(
            config, "moe_layer_end_index",
            [config.num_hidden_layers - 1, config.num_hidden_layers - 1])
        text_moe_layer_end_index = moe_layer_end_index[0]
        vision_moe_layer_end_index = moe_layer_end_index[1]

        # One correction-bias row per branch (row 0 text, row 1 vision),
        # which requires both branches to have the same expert count.
        assert config.moe_num_experts[0] == config.moe_num_experts[1]
        self.e_score_correction_bias = nn.Parameter(
            torch.empty(2, config.moe_num_experts[0]))

        assert text_moe_layer_start_index <= text_moe_layer_end_index

        if layer_idx >= text_moe_layer_start_index and \
            layer_idx <= text_moe_layer_end_index:
            self.text_experts_gate = ReplicatedLinear(
                config.hidden_size,
                config.moe_num_experts[0],
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.text_experts_gate")

            self.text_experts = FusedMoE(
                num_experts=config.moe_num_experts[0],
                top_k=config.moe_k,
                hidden_size=config.hidden_size,
                intermediate_size=config.moe_intermediate_size[0],
                reduce_results=False,
                renormalize=True,
                quant_config=quant_config,
                e_score_correction_bias=self.e_score_correction_bias[0],
                prefix=f"{prefix}.text_experts")
        else:
            self.text_experts = Ernie4_5_VLMoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                use_bias=getattr(config, 'use_bias', False),
                quant_config=quant_config,
                prefix=f"{prefix}.mlp")

        assert vision_moe_layer_start_index <= vision_moe_layer_end_index
        if layer_idx >= vision_moe_layer_start_index and \
            layer_idx <= vision_moe_layer_end_index:
            self.vision_experts_gate = ReplicatedLinear(
                config.hidden_size,
                config.moe_num_experts[1],
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.vision_experts_gate")

            self.vision_experts = FusedMoE(
                num_experts=config.moe_num_experts[1],
                top_k=config.moe_k,
                hidden_size=config.hidden_size,
                intermediate_size=config.moe_intermediate_size[1],
                reduce_results=False,
                renormalize=True,
                quant_config=quant_config,
                e_score_correction_bias=self.e_score_correction_bias[1],
                prefix=f"{prefix}.vision_experts")
        else:
            self.vision_experts = Ernie4_5_VLMoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                use_bias=getattr(config, 'use_bias', False),
                quant_config=quant_config,
                prefix=f"{prefix}.mlp")

        if self.has_shared_experts:
            intermediate_size = (config.moe_intermediate_size[0] *
                                 config.moe_num_shared_experts)
            # NOTE(review): must_reduce_shared_expert_outputs() is called
            # on text_experts, which looks like it assumes the FusedMoE
            # branch was taken above -- confirm for dense text layers.
            self.shared_experts = Ernie4_5_VLMoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.shared_experts",
                reduce_results=self.text_experts.
                must_reduce_shared_expert_outputs())

    def forward(
        self,
        hidden_states: torch.Tensor,
        visual_token_mask: torch.Tensor,
        **kwargs: object,
    ) -> torch.Tensor:
        """Route tokens through the text or vision experts.

        Args:
            hidden_states: Activations; flattened to ``(-1, hidden_dim)``
                internally and restored to the input shape on return.
            visual_token_mask: ``None``, or a mask that is tiled
                ``hidden_size`` times along dim 1 and cast to bool; true
                rows go to the vision experts, the rest to the text
                experts. Presumably shaped ``(num_tokens, 1)`` -- TODO
                confirm against the caller.

        Returns:
            Expert output (plus shared-expert output when configured),
            with the same shape as the input ``hidden_states``.
        """

        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)

        shared_output = None
        if self.has_shared_experts:
            shared_output = self.shared_experts(hidden_states)

        # NOTE(review): both paths below use text_experts_gate (and the
        # mixed path vision_experts_gate), which only exist when the
        # corresponding FusedMoE branch was built in __init__.
        if visual_token_mask is not None and visual_token_mask.any():
            # Expand the per-token mask to element granularity so it can
            # index the 2-D activations directly.
            visual_token_mask = visual_token_mask.repeat(
                1, self.hidden_size).bool()
            text_token_mask = ~visual_token_mask
            final_hidden_states = torch.zeros_like(hidden_states)

            text_hidden_states = hidden_states[text_token_mask].reshape(
                -1, self.hidden_size)
            vision_hidden_states = hidden_states[visual_token_mask].reshape(
                -1, self.hidden_size)

            text_router_logits, _ = self.text_experts_gate(text_hidden_states)
            final_hidden_states[text_token_mask] = self.text_experts(
                hidden_states=text_hidden_states,
                router_logits=text_router_logits).flatten()

            vision_router_logits, _ = self.vision_experts_gate(
                vision_hidden_states)
            final_hidden_states[visual_token_mask] = self.vision_experts(
                hidden_states=vision_hidden_states,
                router_logits=vision_router_logits).flatten()
        else:
            # text modal input processing directly
            text_router_logits, _ = self.text_experts_gate(hidden_states)

            final_hidden_states = self.text_experts(
                hidden_states=hidden_states, router_logits=text_router_logits)

        if self.has_shared_experts and \
            shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output

        if self.tp_size > 1:
            final_hidden_states = (
                self.text_experts.maybe_all_reduce_tensor_model_parallel(
                    final_hidden_states))

        return final_hidden_states.view(orig_shape)
max_moe_layer_end_index): + self.mlp = Ernie4_5_VLMoeMoE(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Ernie4_5_VLMoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=getattr(config, 'use_bias', False), + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + visual_token_mask: Optional[torch.Tensor], + **kwargs: object, + ) -> torch.Tensor: + + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + if isinstance(self.mlp, Ernie4_5_VLMoeMoE): + hidden_states = self.mlp(hidden_states, visual_token_mask, + **kwargs) + else: + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +# Since Ernie VL distinguishes between text experts and vision experts, +# enabling torch.compile will cause errors. 
# @support_torch_compile(
#     dynamic_arg_dims={
#         "input_ids": 0,
#         "positions": -1,
#         "intermediate_tensors": 0,
#         "inputs_embeds": 0,
#         "visual_token_mask": 0,
#     })
class Ernie4_5_VLMoeModel(nn.Module):
    """Decoder stack of the Ernie 4.5 VL text backbone.

    Holds the token embedding (first pipeline rank only), the decoder
    layers produced by ``make_layers`` and the final RMSNorm (last
    pipeline rank only). On other ranks the embedding / norm are
    ``PPMissingLayer`` placeholders.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the embedding, decoder layers and final norm."""
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.config = config

        self.im_patch_id = config.im_patch_id

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens")
        else:
            self.embed_tokens = PPMissingLayer()

        # make_layers returns the layer index range handled by this
        # module together with the layer container.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Ernie4_5_VLMoeDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        visual_token_mask: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder layers owned by this pipeline rank.

        On the first rank the input is ``inputs_embeds`` if given,
        otherwise the embedding of ``input_ids``; on later ranks it is
        read from ``intermediate_tensors``. ``visual_token_mask`` and
        ``kwargs`` are handed to every decoder layer.

        Returns:
            ``IntermediateTensors`` with "hidden_states" and "residual"
            on non-last ranks, otherwise the normalised hidden states.
        """

        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states, residual,
                                            visual_token_mask, **kwargs)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        # The norm is called with the pending residual; only its first
        # return value is kept.
        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors into this module's parameters.

    Each ``(name, tensor)`` pair is tried against three strategies in
    order: the stacked projections (q/k/v -> ``qkv_proj``, gate/up ->
    ``gate_up_proj``), the fused expert parameters, and finally a plain
    one-to-one load. Names containing "mtp", "vision_model" or
    "resampler_model" are skipped.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The set of (remapped) parameter names that were processed.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    # Params for weights, fp8 weight scales, fp8 activation scales
    # (param_name, weight_name, expert_id, shard_id)
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=max(self.config.moe_num_experts))

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        # With tied embeddings the lm_head tensor is not loaded; it is
        # only recorded as handled.
        if self.config.tie_word_embeddings and name.endswith(
                "lm_head.weight"):
            loaded_params.add("lm_head.weight")
            continue
        # MTP will be supported soon.
        if "mtp" in name or \
            "vision_model" in name or \
            "resampler_model" in name:
            continue

        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue

            if (("mlp.experts." in name) and name not in params_dict):
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if ((name.endswith(".bias") or name.endswith("_bias"))
                    and name not in params_dict):
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Distinguish between vision experts and text experts
            if "mlp.experts" in name:
                # The third-from-last dotted component is parsed as the
                # expert index; indices below moe_num_experts[0] are
                # text experts, the rest are vision experts and are
                # renumbered to start from 0.
                moe_offset = int(name.split(".")[-3])
                vision_expert_start_idx = self.config.moe_num_experts[0]
                is_text_expert = \
                    moe_offset <= vision_expert_start_idx - 1
                if is_text_expert:
                    name = name.replace(".experts.", ".text_experts.")
                else:
                    name = name.replace(
                        f".experts.{moe_offset}",
                        f".vision_experts.{moe_offset-vision_expert_start_idx}"
                    )

            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping

                if weight_name not in name:
                    continue

                # Distinguish between vision experts and text experts
                # NOTE(review): the name was already rewritten above, so
                # the ".experts." replacements below look like no-ops at
                # this point -- confirm before relying on them.
                moe_offset = int(name.split(".")[-3])
                is_text_expert = \
                    moe_offset <= self.config.moe_num_experts[0] - 1

                name = name.replace(weight_name, param_name)
                if is_text_expert:
                    name = name.replace(".experts.", ".text_experts.")
                else:
                    name = name.replace(".experts.", ".vision_experts.")

                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue

                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                param = params_dict[name]

                weight_loader = param.weight_loader
                weight_loader(param,
                              loaded_weight,
                              name,
                              shard_id=shard_id,
                              expert_id=expert_id)
                break
            else:
                # Distinguish between vision expert gate
                # and text expert gate
                # Both gate tensors are transposed before loading.
                if name.endswith("mlp.gate.weight"):
                    name = name.replace("gate.weight",
                                        "text_experts_gate.weight")
                    loaded_weight = loaded_weight.T
                elif name.endswith("mlp.gate.weight_1"):
                    name = name.replace("gate.weight_1",
                                        "vision_experts_gate.weight")
                    loaded_weight = loaded_weight.T

                if "e_score_correction_bias" in name:
                    name = name.replace(".moe_statics.", ".")

                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                param = params_dict[name]

                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params
It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Ernie-MTP model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .llama import LlamaDecoderLayer +from .utils import is_pp_missing_parameter, maybe_prefix + + +class ErnieMultiTokenPredictorLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + 
self.mtp_emb_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.mtp_hidden_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.mtp_linear_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config, + prefix) + + def forward( + self, + inputs_embeds: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + spec_step_index: int = 0, + ) -> torch.Tensor: + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 + + inputs_embeds = self.mtp_emb_norm(inputs_embeds) + previous_hidden_states = self.mtp_hidden_norm(previous_hidden_states) + + hidden_states = self.mtp_linear_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + residual=None) + hidden_states = residual + hidden_states + + return hidden_states + + +class ErnieMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + ErnieMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: 
Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]( + inputs_embeds, + positions, + previous_hidden_states, + spec_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + lm_head: ParallelLMHead, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> torch.Tensor: + self.layers[str(self.mtp_start_layer_idx + spec_step_idx)] + logits = self.logits_processor(lm_head, hidden_states, + sampling_metadata) + return logits + + +class ErnieMTP(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + self.config = vllm_config.model_config.hf_config + self.model = ErnieMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) + self.sampler = get_sampler() + + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + assert spec_step_idx == 0, "ernie_mtp only support predict one token" + hidden_states = self.model(input_ids, positions, hidden_states, + inputs_embeds, spec_step_idx) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.model.compute_logits(hidden_states, self.lm_head, + sampling_metadata, spec_step_idx) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = 
self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + + if self.config.tie_word_embeddings and name.endswith( + "lm_head.weight"): + continue + if "rotary_emb.inv_freq" in name: + continue + if "mtp" in name: + name = self._rewrite_spec_layer_name(self.config, name) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + if "mtp" not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. 
+ if is_pp_missing_parameter(name, self): + continue + + # According to DeepSeek-V3 Technical Report, MTP modules + # shares embedding layer. We only load the first weights. + if "mtp_" not in name and ("embed_tokens" not in name + and "lm_head" not in name): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def _rewrite_spec_layer_name(self, config: PretrainedConfig, + name: str) -> str: + """ + Rewrite the weight name to match the format of the original model. + """ + spec_layer_weight_names = [ + "embed_tokens", "mtp_emb_norm", "mtp_hidden_norm", + "mtp_linear_proj" + ] + layer_idx = config.num_hidden_layers + for weight_name in spec_layer_weight_names: + if weight_name in name: + name = name.replace( + f"model.{weight_name}.0.", + f"model.layers.{layer_idx}.{weight_name}.") + return name + name = name.replace("model.mtp_block.0.", + f"model.layers.{layer_idx}.mtp_block.") + return name diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 56e456c2f1f2a..d0881231fb1e7 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -21,7 +21,7 @@ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder, from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, @@ -647,7 +647,8 @@ class Florence2LanguageModel(nn.Module): encoder_hidden_states = None - if inputs_embeds is not None or encoder_input_ids.numel() > 0: + if ((inputs_embeds is not None and inputs_embeds.numel() > 0) + or 
encoder_input_ids.numel() > 0): # Run encoder attention if a non-zero number of encoder tokens # are provided as input encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, @@ -681,6 +682,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only): self.lm_head = BartParallelLMHead(self.vocab_size, config.d_model, embed_scale=embed_scale) + if self.config.tie_word_embeddings: + self.lm_head.tie_weights(self.model.shared) self.logits_processor = LogitsProcessor(self.vocab_size, config.vocab_size) @@ -749,7 +752,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only): else: if "final_logits_bias" in name: continue - if self.config.tie_word_embeddings and "embed_tokens" in name: + if self.config.tie_word_embeddings and ("embed_tokens" in name + or "lm_head" in name): continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", @@ -860,7 +864,7 @@ class Florence2MultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() pad_token_id = hf_config.pad_token_id diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index b61e0361fe8c3..90af859ab92ec 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -32,7 +32,7 @@ from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -226,7 +226,7 @@ class 
FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() bos_token_id = hf_config.bos_token_id diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 9871b11b37991..f3dc7dde46bdf 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -17,16 +17,17 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) # yapf: disable from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, BoundPromptUpdate, + BaseProcessingInfo, + MultiModalPromptUpdates, + MultiModalPromptUpdatesApplyResult, PlaceholderFeaturesInfo, - PromptReplacement, PromptTargetMatch, - PromptUpdate, PromptUpdateDetails, - find_mm_placeholders, + PromptReplacement, PromptUpdate, + PromptUpdateDetails, replace_token_matches) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder @@ -311,7 +312,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token = hf_processor.boi_token @@ -337,14 +338,10 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): def _apply_token_matches( self, 
prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> list[int]: - token_ids = super()._apply_token_matches( - prompt, - mm_matches, - mm_item_counts, - ) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: + token_ids, res = super()._apply_token_matches(prompt, + mm_prompt_updates) # "\n\n\n" and "\n\n\n\n" are single tokens # Since our replacement can insert "\n\n" next to "\n" @@ -373,13 +370,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): [newline_4], ) - return token_ids + return token_ids, res def _find_mm_placeholders( self, - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], new_token_ids: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" tokenizer = self.info.get_tokenizer() @@ -404,8 +400,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): repl_token_ids.extend(repl_toks) repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) - repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids, - mm_item_counts) + repls = super()._find_mm_placeholders(repl_token_ids, + mm_prompt_updates) return { modality: [ diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index a0c3bb50070b3..d59dde1560aea 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -24,16 +24,17 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import 
(ImageProcessorItems, MultiModalDataItems, MultiModalDataParser) # yapf: disable from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, BoundPromptUpdate, + BaseProcessingInfo, + MultiModalPromptUpdates, + MultiModalPromptUpdatesApplyResult, PlaceholderFeaturesInfo, - PromptReplacement, PromptTargetMatch, - PromptUpdate, PromptUpdateDetails, - find_mm_placeholders, + PromptReplacement, PromptUpdate, + PromptUpdateDetails, replace_token_matches) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder @@ -209,7 +210,7 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -254,14 +255,10 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] def _apply_token_matches( self, prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> list[int]: - token_ids = super()._apply_token_matches( - prompt, - mm_matches, - mm_item_counts, - ) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: + token_ids, res = super()._apply_token_matches(prompt, + mm_prompt_updates) # "\n\n\n" and "\n\n\n\n" are single tokens # Since our replacement can insert "\n\n" next to "\n" @@ -290,13 +287,12 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] [newline_4], ) - return token_ids + return token_ids, res def _find_mm_placeholders( self, - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], new_token_ids: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: # We need to detect "\n\n" inside 
"\n\n\n" and "\n\n\n\n" tokenizer = self.info.get_tokenizer() @@ -321,8 +317,8 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] repl_token_ids.extend(repl_toks) repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) - repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids, - mm_item_counts) + repls = super()._find_mm_placeholders(repl_token_ids, + mm_prompt_updates) return { modality: [ diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 88c53c8363275..662728e6b1393 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -59,7 +59,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, VideoItem) + MultiModalKwargsItems, VideoItem) from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -74,7 +74,8 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from ..layers.activation import SiluAndMul from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .qwen2_vl import _qwen2vl_field_config, apply_rotary_pos_emb_vision +from .qwen2_vl import (_create_qwen2vl_field_factory, + apply_rotary_pos_emb_vision) from .utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -126,7 +127,7 @@ class Glm4vVideoPixelInputs(TensorSchema): - ctpp: Number of channels * temporal_patch_size * patch_size * patch_size - f: Number of frames - - g: Grid dimensions (3 for grid_t which is usually 1 for processed + - g: Grid dimensions (3 for grid_t which is usually 1 for processed video, grid_h, 
grid_w) """ type: Literal["pixel_values_videos"] = "pixel_values_videos" @@ -141,7 +142,7 @@ class Glm4vVideoEmbeddingInputs(TensorSchema): - p: Number of video patches across all frames - h: Hidden size (must match language model backbone) - f: Number of frames - - g: Grid dimensions (3 for grid_t which is usually 1 for processed + - g: Grid dimensions (3 for grid_t which is usually 1 for processed video, grid_h, grid_w) """ type: Literal["video_embeds"] = "video_embeds" @@ -234,7 +235,8 @@ class Glm4vVisionAttention(nn.Module): total_num_kv_heads=num_heads, bias=False, quant_config=quant_config, - prefix=f"{prefix}.qkv", + # Change qkv prefix to align with GLM-4.5V-FP8 quantization config + prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv", ) self.proj = RowParallelLinear( input_size=projection_size, @@ -1152,13 +1154,15 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return _qwen2vl_field_config(hf_inputs) + return _create_qwen2vl_field_factory( + self.info.get_hf_config().vision_config.spatial_merge_size)( + hf_inputs) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor( @@ -1175,14 +1179,16 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): merge_length = image_processor.merge_size**2 def get_image_replacement_glm4v(item_idx: int): - grid_thw = out_mm_kwargs["image_grid_thw"][item_idx] + out_item = out_mm_kwargs["image"][item_idx] + grid_thw = out_item["image_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) num_tokens = int(grid_thw.prod()) // merge_length return [hf_processor.image_token_id] * 
num_tokens def get_video_replacement_glm4v(item_idx: int): - grid_thw = out_mm_kwargs["video_grid_thw"][item_idx] + out_item = out_mm_kwargs["video"][item_idx] + grid_thw = out_item["video_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) video, metadata = mm_items["video"][item_idx] diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index aff491f9596c3..fe5e46a99826f 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -131,10 +131,10 @@ class Glm4MoE(nn.Module): # Load balancing settings. vllm_config = get_current_vllm_config() - parallel_config = vllm_config.parallel_config + eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb - self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts self.n_physical_experts = (self.n_logical_experts + self.n_redundant_experts) diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py index 0624640054d16..322c5619c1783 100644 --- a/vllm/model_executor/models/glm4_moe_mtp.py +++ b/vllm/model_executor/models/glm4_moe_mtp.py @@ -180,14 +180,13 @@ class Glm4MoeMTP(nn.Module, SupportsPP): self, input_ids: torch.Tensor, positions: torch.Tensor, - previous_hidden_states: torch.Tensor, + hidden_states: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, - previous_hidden_states, inputs_embeds, - spec_step_idx) + hidden_states = self.model(input_ids, positions, hidden_states, + inputs_embeds, spec_step_idx) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 1751fccd08b06..bf33575859aea 100644 --- 
a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -30,7 +30,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -503,7 +503,7 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 7c7712dbe106e..9c1c05320cf36 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -27,7 +27,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.utils import cdiv -from .utils import extract_layer_index, maybe_prefix +from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index, + maybe_prefix) class OAIAttention(nn.Module): @@ -159,7 +160,7 @@ class MLPBlock(torch.nn.Module): prefix=f"{prefix}.experts", apply_router_weight_on_input=False, has_bias=True, - activation="swiglu_oai") + activation="swigluoai") def forward(self, x: torch.Tensor) -> torch.Tensor: t = self.norm(x) @@ -173,12 +174,15 @@ class TransformerBlock(torch.nn.Module): def __init__( self, config: GptOssConfig, + cache_config: CacheConfig, quant_config: QuantizationConfig, prefix: str = "", ): super().__init__() self.layer_idx = extract_layer_index(prefix) - self.attn = OAIAttention(config, 
prefix=f"{prefix}.attn") + self.attn = OAIAttention(config, + prefix=f"{prefix}.attn", + cache_config=cache_config) self.mlp = MLPBlock(config, self.layer_idx, quant_config=quant_config, @@ -202,7 +206,9 @@ class GptOssModel(nn.Module): ): super().__init__() self.config = vllm_config.model_config.hf_config + self.cache_config = vllm_config.cache_config self.quant_config = vllm_config.quant_config + self.parallel_config = vllm_config.parallel_config self.config.hidden_size = self.config.hidden_size self.embedding = VocabParallelEmbedding( self.config.vocab_size, @@ -211,6 +217,7 @@ class GptOssModel(nn.Module): self.layers = torch.nn.ModuleList([ TransformerBlock( self.config, + cache_config=self.cache_config, quant_config=self.quant_config, prefix=maybe_prefix(prefix, f"block.{layer_idx}"), ) for layer_idx in range(self.config.num_hidden_layers) @@ -225,8 +232,364 @@ class GptOssModel(nn.Module): x = self.norm(x) return x + def _load_weights_mxfp4( + self, + ep_rank_end: int, + ep_rank_start: int, + heads_per_rank: int, + head_start: int, + weights: Iterable[tuple[str, torch.Tensor]], + stacked_params_mapping: list[tuple[str, ...]], + ) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + mxfp4_block = 32 + use_ep = self.parallel_config.enable_expert_parallel + num_experts = self.config.num_local_experts + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + intermediate_size = self.config.intermediate_size + intermediate_size_block = intermediate_size // mxfp4_block + per_rank_intermediate_size_block = cdiv(intermediate_size_block, + tp_size) + per_rank_intermediate_size = (per_rank_intermediate_size_block * + mxfp4_block) + + # Calculate common slicing bounds for current rank + tp_rank_start = tp_rank * per_rank_intermediate_size + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, + intermediate_size) + + for name, weight in weights: + # FIXME(woosuk): Remove 
this after testing. + weight = weight.cuda() + + if ".w13_weight_scale" in name: + # Handle MLP gate and up projection weights scale + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end, + ...] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w2_weight_scale" in name: + # Handle MLP down projection weights + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[..., tp_rank_start // + mxfp4_block:tp_rank_end // + mxfp4_block] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w13_weight" in name: + # Handle MLP gate and up projection weights + # flat weight from (E, 2 * N, block_size, entry_per_block) + # to (E, 2 * N, -1), shouldn't trigger copy for contiguous + weight = weight.view(num_experts, 2 * intermediate_size, + -1).contiguous() + + # Extract gate and up projection parts + # since the weight is shuffled, we can slice directly + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end, + ...] 
+ + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w2_weight" in name: + # Handle MLP down projection weights + # same flatten here, but since 2 mx4 value are packed in 1 + # uint8, divide by 2 + weight = weight.view(num_experts, -1, + intermediate_size // 2).contiguous() + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[..., + tp_rank_start // 2:tp_rank_end // 2] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w13_bias" in name: + # Handle MLP gate and up projection biases + # Extract gate and up projection bias parts + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w2_bias" in name: + # Handle MLP down projection bias + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if use_ep: + weight = weight[ep_rank_start:ep_rank_end, ...] 
+ else: + # (only load on rank 0 to avoid duplication) + if tp_rank != 0: + weight.zero_() + weight_loader(param, + weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif "sinks" in name: + # Handle attention sinks (distributed across ranks) + param = params_dict[name] + narrow_weight = weight.narrow(0, head_start, heads_per_rank) + param.data.copy_(narrow_weight) + loaded_params.add(name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, weight) + else: + weight_loader(param, weight, shard_id) + break + else: + # Handle all other weights with potential renaming + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight) + loaded_params.add(name) + return loaded_params + + def _load_weights_other( + self, + ep_rank_start: int, + ep_rank_end: int, + heads_per_rank: int, + head_start: int, + weights: Iterable[tuple[str, torch.Tensor]], + stacked_params_mapping: list[tuple[str, ...]], + ) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + use_ep = self.parallel_config.enable_expert_parallel + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + intermediate_size = self.config.intermediate_size + per_rank_intermediate_size = cdiv(intermediate_size, tp_size) + # Calculate common slicing bounds for current rank + tp_rank_start = tp_rank * per_rank_intermediate_size + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, + intermediate_size) + + for name, weight in weights: + if ".w13_weight" in name: + # Handle MLP gate and up 
projection weights + # Extract gate and up projection parts + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, :, + 2 * tp_rank_start:2 * tp_rank_end] + + narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() + param = params_dict[name] + + param.copy_(narrow_weight) + loaded_params.add(name) + continue + elif ".w2_weight" in name: + # Handle MLP down projection weights + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, tp_rank_start:tp_rank_end, :] + narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() + param = params_dict[name] + + param.copy_(narrow_weight) + loaded_params.add(name) + continue + elif ".w13_bias" in name: + # Handle MLP gate and up projection biases + # Extract gate and up projection bias parts + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end] + + param = params_dict[name] + param.copy_(narrow_weight) + loaded_params.add(name) + continue + elif ".w2_bias" in name: + # Handle MLP down projection bias + if use_ep: + weight = weight[ep_rank_start:ep_rank_end, ...] 
+ else: + # (only load on rank 0 to avoid duplication) + if tp_rank != 0: + weight.zero_() + param = params_dict[name] + param.copy_(weight) + loaded_params.add(name) + continue + elif "sinks" in name: + # Handle attention sinks (distributed across ranks) + param = params_dict[name] + narrow_weight = weight.narrow(0, head_start, heads_per_rank) + param.data.copy_(narrow_weight) + loaded_params.add(name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, weight) + else: + weight_loader(param, weight, shard_id) + break + else: + # Handle all other weights with potential renaming + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight) + loaded_params.add(name) + return loaded_params + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv", ".q_proj", "q"), + (".qkv", ".k_proj", "k"), + (".qkv", ".v_proj", "v"), + ] + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + # Attention heads per rank + heads_per_rank = self.config.num_attention_heads // tp_size + head_start = tp_rank * heads_per_rank + + ep_size = get_ep_group().world_size + ep_rank = get_ep_group().rank + num_experts = self.config.num_local_experts + experts_per_rank = num_experts // ep_size + ep_rank_start = ep_rank * experts_per_rank + ep_rank_end = (ep_rank + 1) * experts_per_rank + + quant_method = (self.config.quantization_config['quant_method'] if + hasattr(self.config, "quantization_config") else None) + if quant_method == "mxfp4": + return 
self._load_weights_mxfp4(ep_rank_end, ep_rank_start, + heads_per_rank, head_start, + weights, stacked_params_mapping) + else: + return self._load_weights_other(ep_rank_end, ep_rank_start, + heads_per_rank, head_start, + weights, stacked_params_mapping) + class GptOssForCausalLM(nn.Module): + packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]} + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".self_attn.": ".attn.", + ".post_attention_layernorm.": ".mlp.norm.", + }, + orig_to_new_suffix={ + ".embed_tokens.weight": ".embedding.weight", + ".input_layernorm.weight": ".attn.norm.weight", + ".post_attention_layernorm.weight": ".mlp.norm.weight", + + # MoE MXFP4 weights + ".gate_up_proj_blocks": ".w13_weight", + ".down_proj_blocks": ".w2_weight", + ".gate_up_proj_scales": ".w13_weight_scale", + ".down_proj_scales": ".w2_weight_scale", + + # MoE other weights + ".gate_up_proj": ".w13_weight", + ".down_proj": ".w2_weight", + + # MoE Bias + ".gate_up_proj_bias": ".w13_bias", + ".down_proj_bias": ".w2_bias", + }, + ) def __init__( self, @@ -235,16 +598,17 @@ class GptOssForCausalLM(nn.Module): ): super().__init__() self.vllm_config = vllm_config - self.model_config = vllm_config.model_config.hf_config + self.config = vllm_config.model_config.hf_config + self.model = GptOssModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"), ) self.lm_head = ParallelLMHead( - self.model_config.vocab_size, - self.model_config.hidden_size, + self.config.vocab_size, + self.config.hidden_size, ) - self.logits_processor = LogitsProcessor(self.model_config.vocab_size) + self.logits_processor = LogitsProcessor(self.config.vocab_size) def forward(self, input_ids: torch.Tensor, @@ -261,354 +625,11 @@ class GptOssForCausalLM(nn.Module): sampling_metadata) return logits - def _load_weights_mxfp4( - self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - rename_mapping = { - "self_attn": "attn", - "input_layernorm.weight": "attn.norm.weight", - 
"post_attention_layernorm.weight": "mlp.norm.weight", - "embed_tokens": "embedding", - } - - def maybe_rename(name: str) -> str: - for remap_name, new_name in rename_mapping.items(): - if remap_name in name: - return name.replace(remap_name, new_name) - return name - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - mxfp4_block = 32 - - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - intermediate_size = self.model_config.intermediate_size - intermediate_size_block = intermediate_size // mxfp4_block - per_rank_intermediate_size_block = cdiv(intermediate_size_block, - tp_size) - per_rank_intermediate_size = (per_rank_intermediate_size_block * - mxfp4_block) - - # Calculate common slicing bounds for current rank - tp_rank_start = tp_rank * per_rank_intermediate_size - tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, - intermediate_size) - - # Attention heads per rank - heads_per_rank = self.model_config.num_attention_heads // tp_size - head_start = tp_rank * heads_per_rank - - use_ep = self.vllm_config.parallel_config.enable_expert_parallel - ep_size = get_ep_group().world_size - ep_rank = get_ep_group().rank - num_experts = self.model_config.num_local_experts - experts_per_rank = num_experts // ep_size - ep_rank_start = ep_rank * experts_per_rank - ep_rank_end = (ep_rank + 1) * experts_per_rank - - for name, weight in weights: - # FIXME(woosuk): Remove this after testing. 
- weight = weight.cuda() - - if "gate_up_proj_blocks" in name: - # Handle MLP gate and up projection weights - new_name = name.replace("gate_up_proj_blocks", "w13_weight") - - # flat weight from (E, 2 * N, block_size, entry_per_block) - # to (E, 2 * N, -1), shouldn't trigger copy for contiguous - weight = weight.view(num_experts, 2 * intermediate_size, - -1).contiguous() - - # Extract gate and up projection parts - # since the weight is shuffled, we can slice directly - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end, - ...] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - - elif "down_proj_blocks" in name: - # Handle MLP down projection weights - new_name = name.replace("down_proj_blocks", "w2_weight") - # same flatten here, but since 2 mx4 value are packed in 1 - # uint8, divide by 2 - weight = weight.view(num_experts, -1, - intermediate_size // 2).contiguous() - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[..., - tp_rank_start // 2:tp_rank_end // 2] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - - elif "gate_up_proj_scales" in name: - # Handle MLP gate and up projection weights scale - new_name = name.replace("gate_up_proj_scales", - "w13_weight_scale") - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end, - ...] 
- - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - - elif "down_proj_scales" in name: - # Handle MLP down projection weights - new_name = name.replace("down_proj_scales", "w2_weight_scale") - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[..., tp_rank_start // - mxfp4_block:tp_rank_end // - mxfp4_block] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - elif "gate_up_proj_bias" in name: - # Handle MLP gate and up projection biases - new_name = name.replace("gate_up_proj_bias", "w13_bias") - - # Extract gate and up projection bias parts - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - - elif "down_proj_bias" in name: - # Handle MLP down projection bias - new_name = name.replace("down_proj_bias", "w2_bias") - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - if use_ep: - weight = weight[ep_rank_start:ep_rank_end, ...] 
- else: - # (only load on rank 0 to avoid duplication) - if tp_rank != 0: - weight.zero_() - weight_loader(param, - weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - elif "sinks" in name: - # Handle attention sinks (distributed across ranks) - name = name.replace("self_attn", "attn") - param = params_dict[name] - narrow_weight = weight.narrow(0, head_start, heads_per_rank) - param.data.copy_(narrow_weight) - loaded_params.add(name) - elif "q_proj" in name or "k_proj" in name or "v_proj" in name: - shard_id = ("q" if "q_proj" in name else - "k" if "k_proj" in name else "v") - name = name.replace("self_attn", "attn") - param_name = name.replace(f"{shard_id}_proj", "qkv") - param = params_dict[param_name] - weight_loader = param.weight_loader - weight_loader(param, weight, loaded_shard_id=shard_id) - loaded_params.add(param_name) - else: - # Handle all other weights with potential renaming - renamed_name = maybe_rename(name) - if renamed_name not in params_dict: - continue - param = params_dict[renamed_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, weight) - loaded_params.add(renamed_name) - - return loaded_params - - def _load_weights_other( - self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - rename_mapping = { - "self_attn": "attn", - "input_layernorm.weight": "attn.norm.weight", - "post_attention_layernorm.weight": "mlp.norm.weight", - "embed_tokens": "embedding", - } - - def maybe_rename(name: str) -> str: - for remap_name, new_name in rename_mapping.items(): - if remap_name in name: - return name.replace(remap_name, new_name) - return name - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - intermediate_size = self.model_config.intermediate_size - - per_rank_intermediate_size = cdiv(intermediate_size, tp_size) - # 
Calculate common slicing bounds for current rank - tp_rank_start = tp_rank * per_rank_intermediate_size - tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, - intermediate_size) - - # Attention heads per rank - heads_per_rank = self.model_config.num_attention_heads // tp_size - head_start = tp_rank * heads_per_rank - - use_ep = self.vllm_config.parallel_config.enable_expert_parallel - ep_size = get_ep_group().world_size - ep_rank = get_ep_group().rank - num_experts = self.model_config.num_local_experts - experts_per_rank = num_experts // ep_size - ep_rank_start = ep_rank * experts_per_rank - ep_rank_end = (ep_rank + 1) * experts_per_rank - - for name, weight in weights: - if ".experts.gate_up_proj" in name and "bias" not in name: - # Handle MLP gate and up projection weights - new_name = name.replace(".experts.gate_up_proj", - ".experts.w13_weight") - - # Extract gate and up projection parts - # since the weight is shuffled, we can slice directly - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[:, :, - 2 * tp_rank_start:2 * tp_rank_end] - - narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() - param = params_dict[new_name] - - param.copy_(narrow_weight) - loaded_params.add(new_name) - - elif ".experts.down_proj" in name and "bias" not in name: - # Handle MLP down projection weights - new_name = name.replace(".experts.down_proj", - ".experts.w2_weight") - - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] 
- else: - narrow_weight = weight[:, tp_rank_start:tp_rank_end, :] - narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() - param = params_dict[new_name] - - param.copy_(narrow_weight) - loaded_params.add(new_name) - - elif "gate_up_proj_bias" in name: - # Handle MLP gate and up projection biases - new_name = name.replace("gate_up_proj_bias", "w13_bias") - - # Extract gate and up projection bias parts - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end] - - param = params_dict[new_name] - - param.copy_(narrow_weight) - loaded_params.add(new_name) - - elif "down_proj_bias" in name: - # Handle MLP down projection bias - new_name = name.replace("down_proj_bias", "w2_bias") - - if use_ep: - weight = weight[ep_rank_start:ep_rank_end, ...] - else: - # (only load on rank 0 to avoid duplication) - if tp_rank != 0: - weight.zero_() - param = params_dict[new_name] - param.copy_(weight) - loaded_params.add(new_name) - elif "sinks" in name: - # Handle attention sinks (distributed across ranks) - name = name.replace("self_attn", "attn") - param = params_dict[name] - narrow_weight = weight.narrow(0, head_start, heads_per_rank) - param.data.copy_(narrow_weight) - loaded_params.add(name) - elif "q_proj" in name or "k_proj" in name or "v_proj" in name: - shard_id = ("q" if "q_proj" in name else - "k" if "k_proj" in name else "v") - name = name.replace("self_attn", "attn") - param_name = name.replace(f"{shard_id}_proj", "qkv") - param = params_dict[param_name] - weight_loader = param.weight_loader - weight_loader(param, weight, loaded_shard_id=shard_id) - loaded_params.add(param_name) - else: - # Handle all other weights with potential renaming - - renamed_name = maybe_rename(name) - if renamed_name not in params_dict: - continue - param = params_dict[renamed_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, weight) - 
loaded_params.add(renamed_name) - - return loaded_params - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - quant_method = (self.model_config.quantization_config['quant_method'] - if hasattr(self.model_config, "quantization_config") - else None) - if quant_method == "mxfp4": - return self._load_weights_mxfp4(weights) - else: - return self._load_weights_other(weights) + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index c9e3b74e7c3c4..c3ac3bb78c83d 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -40,7 +40,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -118,7 +118,7 @@ class GraniteSpeechMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> list[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 5704496b9a5d4..f451e65338b78 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -471,7 +471,10 @@ class GraniteMoeHybridModel(nn.Module): # 
Mapping different experts' layout: # from HF (input_linear, output_linear, router) # to vLLM (experts_w13({e}.w1, {e}.w2), experts_w3({e}.w3), gate) - if n.endswith('.block_sparse_moe.input_linear.weight'): + # The renaming and parameter loading logic is the same for weight + # and weight_scale tensors so we can reuse them without issues. + if (n.endswith('.block_sparse_moe.input_linear.weight') or + n.endswith('.block_sparse_moe.input_linear.weight_scale')): for e in range(p.size(0)): w1_name = n.replace( '.block_sparse_moe.input_linear.weight', @@ -490,7 +493,8 @@ class GraniteMoeHybridModel(nn.Module): w3_name, shard_id='w3', expert_id=e) - elif n.endswith('.block_sparse_moe.output_linear.weight'): + elif (n.endswith('.block_sparse_moe.output_linear.weight') or + n.endswith('.block_sparse_moe.output_linear.weight_scale')): for e in range(p.size(0)): w2_name = n.replace( '.block_sparse_moe.output_linear.weight', diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 9e7490e3c4f07..1b3d541c65cf8 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -20,7 +20,7 @@ from vllm.sequence import PoolerOutput from vllm.tasks import PoolingTask from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config -from .interfaces import SupportsV0Only +from .interfaces_base import default_pooling_type logger = init_logger(__name__) @@ -215,7 +215,8 @@ class GritLMPooler(Pooler): return build_output(pooled_data) -class GritLM(LlamaForCausalLM, SupportsV0Only): +@default_pooling_type("MEAN") +class GritLM(LlamaForCausalLM): """This class implements the embedding model for parasail-ai/GritLM-7B-vllm. 
The class inherits from LlamaForCausalLM and provides a custom pooling @@ -241,7 +242,6 @@ class GritLM(LlamaForCausalLM, SupportsV0Only): prefix: str = "", **kwargs, ) -> None: - # Use full attention for pooling (this is why V1 is not supported yet) if vllm_config.model_config.runner_type == "pooling": hf_config = vllm_config.model_config.hf_config hf_config.is_causal = False diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index c3e4f81597adb..87e451a2769ea 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -17,11 +17,12 @@ from transformers import PretrainedConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs +from vllm.multimodal.inputs import MultiModalKwargsItems from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) -from vllm.multimodal.processing import (MultiModalHashes, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.processing import (MultiModalProcessingInfo, + PromptReplacement, PromptUpdate, + PromptUpdateDetails) from vllm.transformers_utils.tokenizer import AnyTokenizer from .intern_vit import InternVisionModel @@ -425,18 +426,19 @@ class H2OVLMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if "image_num_patches" in out_mm_kwargs: - image_num_patches = out_mm_kwargs["image_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "image_num_patches" in out_mm_data: + image_num_patches = out_mm_data["image_num_patches"] assert isinstance(image_num_patches, torch.Tensor) image_num_patches = image_num_patches.tolist() - elif "image_embeds" in 
out_mm_kwargs: + elif "image_embeds" in out_mm_data: # TODO: Use image size information in dictionary embedding inputs # to compute num_patches (similar to Qwen2-VL) - image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + image_num_patches = [None] * len(out_mm_data["image_embeds"]) else: image_num_patches = [] @@ -477,9 +479,7 @@ class H2OVLMultiModalProcessor( mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 1 vs > 1 # Since the processing cache assumes that the processor output is # invariant of how many images are passed per prompt, we only @@ -490,7 +490,6 @@ class H2OVLMultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, ) return super()._cached_apply_hf_processor( @@ -498,7 +497,6 @@ class H2OVLMultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, ) diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index e5c94c7f3a706..53f0585541b1c 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -33,12 +33,13 @@ from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + 
MultiModalKwargsItems) from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate) + BaseProcessingInfo, PromptReplacement, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors @@ -53,6 +54,21 @@ IMAGE_TOKEN: str = "<|dummy3|>" VIDEO_TOKEN: str = "<|_unuse_missing_100270|>" +# Based on combine_frames_into_images in +# https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py +def get_num_combined_frames( + num_frames: int, + max_grid_shape: tuple[int, int] = (3, 3), +) -> int: + max_num_grids = max_grid_shape[0] * max_grid_shape[1] + + # Calculate the number of canvases needed. + num_canvases = num_frames // max_num_grids + leftover_frames = num_frames % max_num_grids + + return num_canvases + (leftover_frames > 0) + + class HCXVisionMultimodalPixelInputs(TypedDict): type: Literal["pixel_values"] pixel_values_images: list[torch.Tensor] @@ -172,23 +188,20 @@ class HCXVisionMultiModalProcessor( def replace_multimodal_token( token_ids: torch.Tensor, target_token: int, - repeats: list, + repeats: list[int], ): - output = list() + output = list[int]() _repeats_idx = 0 for token_id in token_ids: if token_id == target_token: - output += [ - token_id.item(), - ] * repeats[_repeats_idx] + output += [token_id.item()] * repeats[_repeats_idx] _repeats_idx += 1 else: - output += [ - token_id.item(), - ] + output += [token_id.item()] + return torch.tensor(output, device=token_ids.device) - for video_idx, video_arr in enumerate(mm_data.get("videos", list())): + for video_idx, video_arr in enumerate(mm_data.get("videos", [])): if video_arr.dtype == np.uint8: continue mm_data["videos"][video_idx] = video_arr.astype(np.uint8) @@ -205,88 +218,68 @@ class HCXVisionMultiModalProcessor( if len(mm_data) > 0: # batchify 
input as a single item images = mm_data.get("images", None) - num_images = 0 - if images is not None: - num_images = len(images) - images = [ - images, - ] # batchify + batched_images = None if images is None else [images] - videos = mm_data.get("videos", - None) # list of video in single conversation - num_videos = 0 - if videos is not None: - num_videos = len(videos) - videos = [ - videos, - ] # batchify + # list of video in single conversation + videos = mm_data.get("videos", None) + batched_videos = None if videos is None else [videos] _processed_outputs = self.info.ctx.call_hf_processor( hf_processor=self.info.get_hf_processor(**mm_kwargs), data=dict( text=None, - images=images, - videos=videos, + images=batched_images, + videos=batched_videos, ), ) # mm-only for k, v in _processed_outputs.items(): - if len(v) < 1: - continue - elif k.endswith("_images"): - # list of list of 4D tensor -> list of 4D tensor + if isinstance(v, list) and len(v) > 0: + assert len(v) == 1 _processed_outputs[k] = v[0] - elif k.endswith("_videos"): - # list of list of 4D tensor -> list of 4D tensor - v = v[0] - if k == "pixel_values_videos": - v = torch.cat(v, dim=0) - _c, _w, _h = v.shape[-3:] - v = v.reshape(num_videos, -1, _c, _w, _h) - v = list(torch.unbind(v, dim=0)) - _processed_outputs[k] = v - if num_images > 0: + if images: tokenizer = self.info.get_tokenizer() + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) processed_outputs["input_ids"] = torch.stack([ replace_multimodal_token( token_ids=_input_ids, - target_token=tokenizer.convert_tokens_to_ids( - IMAGE_TOKEN), + target_token=image_token_id, repeats=_processed_outputs[ "vision_query_lengths_images"], ) for _input_ids in processed_outputs["input_ids"] ], dim=0) - if num_videos > 0: - tokenizer = self.info.get_tokenizer() - processed_outputs["input_ids"] = torch.stack([ - replace_multimodal_token( - token_ids=_input_ids, - target_token=tokenizer.convert_tokens_to_ids( - VIDEO_TOKEN), - 
repeats=_processed_outputs[ - "vision_query_lengths_videos"], - ) for _input_ids in processed_outputs["input_ids"] - ], - dim=0) - - _ratios = [ - len(_pixel_values) for _pixel_values in - _processed_outputs["pixel_values_videos"] - ] + if videos: _num_per_videos = [ - int(_e / sum(_ratios) * - len(_processed_outputs["vision_query_lengths_videos"])) - for _e in _ratios + get_num_combined_frames(len(video)) for video in videos + ] + _processed_outputs["pixel_values_videos"] = [ + _processed_outputs["pixel_values_videos"] + [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])] + for _i in range(len(videos)) ] _processed_outputs["vision_query_lengths_videos"] = [ _processed_outputs["vision_query_lengths_videos"] [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])] - for _i in range(0, num_videos) + for _i in range(len(videos)) ] + tokenizer = self.info.get_tokenizer() + video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN) + processed_outputs["input_ids"] = torch.stack([ + replace_multimodal_token( + token_ids=_input_ids, + target_token=video_token_id, + repeats=[ + sum(lens) for lens in + _processed_outputs["vision_query_lengths_videos"] + ], + ) for _input_ids in processed_outputs["input_ids"] + ], + dim=0) + processed_outputs.update(_processed_outputs) return processed_outputs @@ -295,7 +288,7 @@ class HCXVisionMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() placeholder = { @@ -306,21 +299,22 @@ class HCXVisionMultiModalProcessor( def get_replacement_hyperclovax( item_idx: int, modality: str, - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ): - num_tokens = None + out_item = out_mm_kwargs[modality][item_idx] + if modality == "image": + lens = out_item["vision_query_lengths_images"].data num_tokens = 
self.info.get_num_image_tokens( - vision_query_length=out_mm_kwargs[ - "vision_query_lengths_images"][item_idx], ) - if modality == "video": + vision_query_length=lens) + elif modality == "video": + lens = out_item["vision_query_lengths_videos"].data num_tokens = self.info.get_num_video_tokens( - vision_query_length=out_mm_kwargs[ - "vision_query_lengths_videos"][item_idx], ) - assert isinstance(num_tokens, int) - return [ - placeholder[modality], - ] * num_tokens + vision_query_length=lens) + else: + raise NotImplementedError(modality) + + return [placeholder[modality]] * num_tokens return [ PromptReplacement( @@ -374,7 +368,7 @@ def _build_hcxvision_hf_processor( info: HCXVisionProcessingInfo, dummy_inputs: BaseDummyInputsBuilder[HCXVisionProcessingInfo], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, HCXVisionProcessingInfo): return HCXVisionMultiModalProcessor( @@ -936,8 +930,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): target_group_size = 0 elif video_group_size < target_group_size: - raise RuntimeError(f"video_group_size < target_group_size!! \ - [{video_group_size} < {target_group_size}]") + raise RuntimeError( + f"{video_group_size=} < {target_group_size=}") assert len(target_features ) == 0, f"target_features is not empty!! 
{target_features}" @@ -1121,9 +1115,8 @@ def reshape_and_unpad_image_features( base_image_feature = image_feature[0] image_feature = image_feature[1:] - assert (height * width == base_image_feature.shape[0] - ), f"height: {height}, width: {width}, \ - base_image_feature.shape[0]: {base_image_feature.shape[0]}" + assert height * width == base_image_feature.shape[0], ( + f"{height=} * {width=} != {base_image_feature.shape[0]=}") num_patch_width, num_patch_height = get_anyres_image_grid_shape( image_size, possible_resolutions, grid_size) diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 9e27200fb1c89..88b2a295905b7 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -27,13 +27,15 @@ from transformers.models.idefics2.configuration_idefics2 import ( Idefics2Config, Idefics2VisionConfig) from vllm.attention.layer import MultiHeadAttention -from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal.utils import run_dp_sharded_vision_model class Idefics2VisionEmbeddings(nn.Module): @@ -118,6 +120,7 @@ class Idefics2VisionAttention(nn.Module): config: Idefics2VisionConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.config = config @@ -130,22 +133,43 @@ class Idefics2VisionAttention(nn.Module): f" {self.num_heads}).") self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout - 
self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.num_heads, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.out_proj = RowParallelLinear( - self.embed_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - self.tp_size = get_tensor_model_parallel_world_size() - self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + tp_size = (1 if use_data_parallel else + get_tensor_model_parallel_world_size()) + assert self.num_heads % tp_size == 0 + self.num_heads_per_partition = self.num_heads // tp_size + + if use_data_parallel: + self.q_size = self.num_heads * self.head_dim + self.qkv_proj = ReplicatedLinear( + self.embed_dim, + 3 * self.q_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = ReplicatedLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + else: + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) self.attn = MultiHeadAttention(self.num_heads_per_partition, self.head_dim, self.scale) @@ -169,18 +193,23 @@ class Idefics2VisionMLP(nn.Module): config: Idefics2VisionConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear( + cls_fc1 = (ReplicatedLinear + if use_data_parallel else ColumnParallelLinear) + self.fc1 = cls_fc1( config.hidden_size, config.intermediate_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.fc1", ) - self.fc2 = RowParallelLinear( + cls_fc2 = (ReplicatedLinear + if 
use_data_parallel else RowParallelLinear) + self.fc2 = cls_fc2( config.intermediate_size, config.hidden_size, bias=True, @@ -202,17 +231,21 @@ class Idefics2EncoderLayer(nn.Module): config: Idefics2Config, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.embed_dim = config.hidden_size - self.self_attn = Idefics2VisionAttention(config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = Idefics2VisionAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + use_data_parallel=use_data_parallel) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = Idefics2VisionMLP(config, quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -254,6 +287,7 @@ class Idefics2Encoder(nn.Module): *, num_hidden_layers_override: Optional[int] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -267,7 +301,8 @@ class Idefics2Encoder(nn.Module): self.layers = nn.ModuleList([ Idefics2EncoderLayer(config, quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}") + prefix=f"{prefix}.layers.{layer_idx}", + use_data_parallel=use_data_parallel) for layer_idx in range(num_hidden_layers) ]) @@ -301,17 +336,20 @@ class Idefics2VisionTransformer(nn.Module): num_hidden_layers_override: Optional[int] = None, require_post_norm: bool = True, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() embed_dim = config.hidden_size self.config = config + self.use_data_parallel = use_data_parallel self.embeddings = Idefics2VisionEmbeddings(config) self.encoder = Idefics2Encoder( config, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, - prefix=f"{prefix}.encoder") + 
prefix=f"{prefix}.encoder", + use_data_parallel=use_data_parallel) num_hidden_layers = config.num_hidden_layers if len(self.encoder.layers) > config.num_hidden_layers: @@ -340,10 +378,38 @@ class Idefics2VisionTransformer(nn.Module): patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes, ) - encoder_outputs = self.encoder(hidden_states) + if self.use_data_parallel: + encoder_outputs = run_dp_sharded_vision_model( + hidden_states, self.encoder) + else: + encoder_outputs = self.encoder(hidden_states) last_hidden_state = self.post_layernorm(encoder_outputs) return last_hidden_state + def _consolidate_qkv_weights( + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> Iterable[tuple[str, torch.Tensor]]: + qkv_idx_mappings = { + ".self_attn.q_proj": 0, + ".self_attn.k_proj": 1, + ".self_attn.v_proj": 2, + } + qkv_weights = {} + for name, loaded_weight in weights: + for weight_name, idx in qkv_idx_mappings.items(): + if weight_name not in name: + continue + new_name = name.replace(weight_name, ".self_attn.qkv_proj") + if new_name not in qkv_weights: + qkv_weights[new_name] = [None] * 3 + qkv_weights[new_name][idx] = loaded_weight + break + else: + yield name, loaded_weight + for key, weight in qkv_weights.items(): + qkv_weight = torch.cat(weight, dim=0) + yield key, qkv_weight + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -356,6 +422,9 @@ class Idefics2VisionTransformer(nn.Module): loaded_params: set[str] = set() layer_count = len(self.encoder.layers) + if self.use_data_parallel: + weights = self._consolidate_qkv_weights(weights) + for name, loaded_weight in weights: # skip pooling header if name.startswith("head."): @@ -373,7 +442,7 @@ class Idefics2VisionTransformer(nn.Module): continue for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: + if weight_name not in name or self.use_data_parallel: continue name = name.replace(weight_name, param_name) 
param = params_dict[name] diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 3c01789b90066..63307470d959b 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -34,7 +34,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import ImageProcessorItems, ImageSize # yapf conflicts with isort for this block # yapf: disable @@ -374,7 +374,7 @@ class Idefics3MultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token, _, _ = self.info._get_image_token(hf_processor) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index c425488f834b5..506732fed3614 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -52,6 +52,12 @@ class SupportsMultiModal(Protocol): MRO of your model class. """ + supports_encoder_tp_data: ClassVar[bool] = False + """ + A flag that indicates whether this model supports + `multimodal_config.mm_encoder_tp_mode="data"`. 
+ """ + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: """ @@ -137,6 +143,11 @@ def supports_multimodal( return getattr(model, "supports_multimodal", False) +def supports_multimodal_encoder_tp_data( + model: Union[type[object], object]) -> bool: + return getattr(model, "supports_encoder_tp_data", False) + + @runtime_checkable class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol): """The interface required for all multi-modal models.""" @@ -641,20 +652,6 @@ def supports_cross_encoding( return is_pooling_model(model) and _supports_cross_encoding(model) -def default_pooling_type(pooling_type: str) -> object: - """Set default_pooling_type decorator. """ - - def func(model: object): - model.default_pooling_type = pooling_type - return model - - return func - - -def get_default_pooling_type(model: Union[type[object], object]) -> str: - return getattr(model, "default_pooling_type", "LAST") - - class SupportsQuant: """The interface required for all models that support quantization.""" diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 697fa020deb46..19a3ef1a3b800 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -144,6 +144,17 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]): MRO of your model class. """ + default_pooling_type: ClassVar[str] = "LAST" + """ + Indicates the + [vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][] + to use by default. + + You can use the + [vllm.model_executor.models.interfaces_base.default_pooling_type][] + decorator to conveniently set this field. 
+ """ + pooler: Pooler """The pooler is only called on TP rank 0.""" @@ -165,3 +176,20 @@ def is_pooling_model( return False return getattr(model, "is_pooling_model", False) + + +_T = TypeVar("_T", bound=type[nn.Module]) + + +def default_pooling_type(pooling_type: str): + """Decorator to set `VllmModelForPooling.default_pooling_type`.""" + + def func(model: _T) -> _T: + model.default_pooling_type = pooling_type # type: ignore + return model + + return func + + +def get_default_pooling_type(model: Union[type[object], object]) -> str: + return getattr(model, "default_pooling_type", "LAST") diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index d0c4bf5450d6d..26bc48ffbd9bc 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -31,7 +31,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type +from .interfaces import SupportsLoRA, SupportsPP +from .interfaces_base import default_pooling_type from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py index d952ced2fa69f..c739e74b058fa 100644 --- a/vllm/model_executor/models/interns1.py +++ b/vllm/model_executor/models/interns1.py @@ -24,7 +24,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, 
MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -399,7 +399,7 @@ class InternS1MultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) img_context_token = hf_processor.image_token @@ -407,15 +407,16 @@ class InternS1MultiModalProcessor( end_image_token = hf_processor.end_image_token video_token = hf_processor.video_token - if "video_num_patches" in out_mm_kwargs: - video_num_patches = out_mm_kwargs["video_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "video_num_patches" in out_mm_data: + video_num_patches = out_mm_data["video_num_patches"] assert isinstance(video_num_patches, torch.Tensor) video_num_patches = video_num_patches.tolist() else: video_num_patches = [] - if "image_num_patches" in out_mm_kwargs: - image_num_patches = out_mm_kwargs["image_num_patches"] + if "image_num_patches" in out_mm_data: + image_num_patches = out_mm_data["image_num_patches"] assert isinstance(image_num_patches, torch.Tensor) image_num_patches = image_num_patches.tolist() else: diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 8e766dd4c4768..b09ed7bbe72a3 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -28,7 +28,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -797,18 
+797,19 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if "image_num_patches" in out_mm_kwargs: - image_num_patches = out_mm_kwargs["image_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "image_num_patches" in out_mm_data: + image_num_patches = out_mm_data["image_num_patches"] assert isinstance(image_num_patches, torch.Tensor) image_num_patches = image_num_patches.tolist() - elif "image_embeds" in out_mm_kwargs: + elif "image_embeds" in out_mm_data: # TODO: Use image size information in dictionary embedding inputs # to compute num_patches (similar to Qwen2-VL) - image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + image_num_patches = [None] * len(out_mm_data["image_embeds"]) else: image_num_patches = [] @@ -854,9 +855,13 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo): def get_video_token(self) -> Optional[str]: text_model_type = self.get_hf_config().get_text_config().model_type - if text_model_type == "qwen2": - return "<|video_pad|>" - return None + video_token_map = { + "qwen2": "<|video_pad|>", + "qwen3": "<|video_pad|>", + "qwen3_moe": "<|video_pad|>", + "gpt_oss": "<|reserved_200000|>", + } + return video_token_map.get(text_model_type) def get_num_frames_with_most_features( self, @@ -966,15 +971,19 @@ class InternVLMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: - prompt_repl: list[PromptUpdate] = super()._get_prompt_updates( - mm_items, hf_processor_mm_kwargs, out_mm_kwargs) + prompt_repl = super()._get_prompt_updates( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + 
out_mm_kwargs=out_mm_kwargs, + ) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if "video_num_patches" in out_mm_kwargs: - video_num_patches = out_mm_kwargs["video_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "video_num_patches" in out_mm_data: + video_num_patches = out_mm_data["video_num_patches"] assert isinstance(video_num_patches, torch.Tensor) video_num_patches = video_num_patches.tolist() else: @@ -992,12 +1001,15 @@ class InternVLMultiModalProcessor( video_context_token=hf_processor.video_token) if self.info.supports_video: - prompt_repl.append( + prompt_repl = [ + *prompt_repl, PromptReplacement( modality="video", target="<video>", replacement=get_video_replacement_internvl, - )) + ) + ] + return prompt_repl diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 0b32d6f256590..3c1a0b68df56e 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -10,6 +10,7 @@ from transformers import JambaConfig from vllm import envs from vllm.attention.layer import Attention +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group @@ -154,10 +155,10 @@ class JambaMambaDecoderLayer(nn.Module): hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.mamba(hidden_states, mamba_cache_params) + output = torch.empty_like(hidden_states) + self.mamba(hidden_states, output, mamba_cache_params) # Fully Connected - hidden_states, residual = self.pre_ff_layernorm( - hidden_states, residual) + hidden_states, residual = self.pre_ff_layernorm(output, residual) hidden_states = self.feed_forward(hidden_states) return hidden_states, residual @@ -278,6 +279,7 @@ ALL_DECODER_LAYER_TYPES = { } +@support_torch_compile class JambaModel(nn.Module): def 
__init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 40c66c2268507..c6dbd62b905e1 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -30,10 +30,10 @@ from vllm.model_executor.layers.quantization.gptq_marlin import ( from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors from vllm.multimodal.inputs import (ImageItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, VideoItem) + MultiModalKwargsItems, VideoItem) from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, ModalityDataItems, MultiModalDataItems, MultiModalDataParser) @@ -44,6 +44,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope +from vllm.utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, @@ -112,8 +113,9 @@ class KeyeImagePixelInputs(TensorSchema): - g: Grid dimensions (3 for t, h, w) """ type: Literal["pixel_values"] - pixel_values: Annotated[torch.Tensor, - TensorShape("b", "np", 3, "ps", "ps")] + pixel_values: Annotated[ + torch.Tensor, + TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})] image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] @@ -145,8 +147,9 @@ class KeyeVideoPixelInputs(TensorSchema): - g: Grid dimensions (3 for t, h, w) """ type: Literal["pixel_values_videos"] - pixel_values_videos: Annotated[torch.Tensor, - TensorShape("b", "np", 3, "ps", "ps")] + pixel_values_videos: Annotated[ + torch.Tensor, + TensorShape("b", 
"np", 3, "ps", "ps", dynamic_dims={"np"})] video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] @@ -1189,7 +1192,7 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor( @@ -1205,7 +1208,8 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): merge_length = image_processor.merge_size**2 def get_replacement_keye(item_idx: int, modality: str): - grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) num_tokens = int(grid_thw.prod()) // merge_length @@ -1295,7 +1299,7 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, return None return quant_config - def _validate_and_reshape_mm_tensor(self, mm_input: object, + def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors, name: str) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): raise ValueError(f"Incorrect type of {name}. 
" @@ -1310,8 +1314,11 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") return torch.concat(list(mm_input)) - else: - return torch.concat(mm_input) + elif is_list_of(mm_input, torch.Tensor): + if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2 + for p in mm_input): + return mm_input + return torch.concat(list(mm_input)) def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[KeyeImageInputs]: diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index 1c7ddd7df7f82..a08a9a62a57c5 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -54,8 +54,7 @@ from transformers import BatchFeature from transformers.activations import GELUActivation from vllm.config import VllmConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_pp_group from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -63,13 +62,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model -from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.interfaces import (SupportsMultiModal, + SupportsPP) from vllm.model_executor.models.moonvit import MoonVitPretrainedModel from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, 
NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -81,7 +81,7 @@ from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .utils import is_pp_missing_parameter, maybe_prefix +from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix # For dummy input only @@ -239,7 +239,7 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: image_token_id = self.info.image_token_id @@ -270,7 +270,8 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]): @MULTIMODAL_REGISTRY.register_processor(KimiVLMultiModalProcessor, info=KimiVLProcessingInfo, dummy_inputs=KimiVLDummyInputsBuilder) -class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): +class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -304,17 +305,21 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): prefix=maybe_prefix(prefix, "language_model"), ) self.unpadded_vocab_size = config.text_config.vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.text_config.hidden_size, - org_num_embeddings=self.config.text_config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.text_config.hidden_size, + org_num_embeddings=self.config.text_config.vocab_size, + 
padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + else: + self.lm_head = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) self.media_placeholder: int = self.config.media_placeholder_token_id - self.tp_rank = get_tensor_model_parallel_rank() - self.tp_world_size = get_tensor_model_parallel_world_size() # ref: qwen2_vl.py def _validate_and_reshape_mm_tensor(self, mm_input: object, diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py new file mode 100644 index 0000000000000..5f3148b47eadc --- /dev/null +++ b/vllm/model_executor/models/lfm2.py @@ -0,0 +1,557 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from typing import Any, Optional + +import torch +import torch.nn as nn +from transformers import Lfm2Config + +from vllm import envs +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.mamba.short_conv import ShortConv +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from 
vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, + SupportsQuant) +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class Lfm2MLP(nn.Module): + + def __init__( + self, + dim: int, + ff_dim: int, + multiple_of: int, + auto_adjust_ff_dim: bool, + ffn_dim_multiplier: Optional[float], + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + if auto_adjust_ff_dim: + ff_dim = int(2 * ff_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + ff_dim = int(ffn_dim_multiplier * ff_dim) + ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of) + + self.w1 = MergedColumnParallelLinear( + input_size=dim, + output_sizes=[ff_dim] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.w2 = RowParallelLinear( + input_size=ff_dim, + output_size=dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.w1(x) + x = self.act_fn(gate_up) + x, _ = self.w2(x) + return x + + +class Lfm2Attention(nn.Module): + + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) 
-> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = hidden_size + self.num_kv_heads = num_kv_heads + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + self.q_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps) + self.k_layernorm = 
RMSNorm(self.head_dim, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + n_tokens, _ = hidden_states.shape + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(n_tokens, self.num_heads, self.head_dim).contiguous() + k = k.view(n_tokens, self.num_kv_heads, self.head_dim).contiguous() + q = self.q_layernorm(q) + k = self.k_layernorm(k) + q, k = self.rotary_emb(positions, q, k) + q = q.view(n_tokens, self.num_heads * self.head_dim) + k = k.view(n_tokens, self.num_kv_heads * self.head_dim) + attn_output = self.attn(q, k, v) + output, _ = self.out_proj(attn_output) + return output + + +class Lfm2AttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.prefix = prefix + self.config = config + self.layer_idx = layer_idx + + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + + self.self_attn = Lfm2Attention( + config=config, + layer_idx=layer_idx, + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + self.feed_forward = Lfm2MLP( + dim=config.block_dim, + ff_dim=config.block_ff_dim, + 
multiple_of=config.block_multiple_of, + auto_adjust_ff_dim=config.block_auto_adjust_ff_dim, + ffn_dim_multiplier=config.block_ffn_dim_multiplier, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.operator_norm(hidden_states) + else: + hidden_states, residual = self.operator_norm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + hidden_states, residual = self.ffn_norm(hidden_states, residual) + return self.feed_forward(hidden_states), residual + + +class Lfm2ShortConvDecoderLayer(nn.Module): + + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.conv = ShortConv( + config=config, + dim=config.conv_dim, + layer_idx=layer_idx, + model_config=model_config, + cache_config=cache_config, + prefix=f"{prefix}.conv", + ) + + self.feed_forward = Lfm2MLP( + dim=config.block_dim, + ff_dim=config.block_ff_dim, + multiple_of=config.block_multiple_of, + auto_adjust_ff_dim=config.block_auto_adjust_ff_dim, + ffn_dim_multiplier=config.block_ffn_dim_multiplier, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual 
is None: + residual = hidden_states + hidden_states = self.operator_norm(hidden_states) + else: + hidden_states, residual = self.operator_norm( + hidden_states, residual) + output = torch.empty_like(hidden_states) + self.conv( + hidden_states, + output, + conv_metadata=None, + ) + hidden_states, residual = self.ffn_norm(output, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Lfm2Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size) + + def get_layer(prefix: str): + layer_idx = extract_layer_index(prefix) + is_attn = self.config.layer_types[layer_idx] == "full_attention" + layer_class = (Lfm2AttentionDecoderLayer + if is_attn else Lfm2ShortConvDecoderLayer) + return layer_class( + config, + layer_idx, + model_config, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + if get_pp_group().is_last_rank: + self.embedding_norm = RMSNorm(config.hidden_size, + eps=config.norm_eps) + else: + self.embedding_norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> 
torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.embedding_norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".w1", ".w1", 0), + (".w1", ".w3", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Lfm2ForCausalLM(nn.Module, 
HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid, SupportsQuant): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "w1": [ + "w1", + "w3", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, ...]: + + return MambaStateDtypeCalculator.short_conv_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + ) + + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int]]: + """ Calculate shapes for LFM2's convolutional cache. + + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + + return MambaStateShapeCalculator.short_conv_state_shape( + tp_world_size=parallel_config.tensor_parallel_size, + intermediate_size=hf_config.conv_dim, + conv_kernel=hf_config.conv_L_cache, + use_v1=use_v1, + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert (not cache_config.enable_prefix_caching + ), "Lfm2 currently does not support prefix caching" + assert envs.VLLM_USE_V1, ( + "Lfm2ForCausalLM doesn't support vLLM v0. 
Please enable v1") + + super().__init__() + self.config = config + self.vllm_config = vllm_config + self.scheduler_config = scheduler_config + self.model_config = vllm_config.model_config + + self.model = Lfm2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = self.config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/llama.py 
b/vllm/model_executor/models/llama.py index 24cd448d8361f..e39a6df843cd4 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -31,6 +31,7 @@ from torch import nn from transformers import LlamaConfig from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -173,7 +174,10 @@ class LlamaAttention(nn.Module): if is_sliding: sliding_window = config.sliding_window - self.attn = Attention( + attn_cls = (EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY else Attention) + + self.attn = attn_cls( self.num_heads, self.head_dim, self.scaling, @@ -349,7 +353,7 @@ class LlamaModel(nn.Module): else: self.norm = PPMissingLayer() - self.aux_hidden_state_layers: tuple[int] = tuple() + self.aux_hidden_state_layers = tuple[int, ...]() self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( @@ -549,10 +553,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None: + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.model.aux_hidden_state_layers = layers - def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]: + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: num_layers = len(self.model.layers) return (2, num_layers // 2, num_layers - 3) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 308cb3e85e27b..ba08e6f81f7fe 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -195,7 +195,9 @@ class Llama4Attention(nn.Module): 
is_neox_style=is_neox_style, ) if not self.nope else None - attn_cls = Attention if self.nope else ChunkedLocalAttention + use_chunked_local_attn = not self.nope and config.attention_chunk_size + attn_cls = (ChunkedLocalAttention + if use_chunked_local_attn else Attention) self.attn = attn_cls( self.num_heads, self.head_dim, @@ -206,7 +208,7 @@ class Llama4Attention(nn.Module): prefix=f"{prefix}.attn", **({ "attention_chunk_size": config.attention_chunk_size - } if not self.nope else {})) + } if use_chunked_local_attn else {})) def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor: floor = torch.floor((positions + 1.0) / self.floor_scale) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 4927d6b62c6d8..0ee26b68345c3 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -22,14 +22,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs) + MultiModalInputs, MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.jsontree import json_map_leaves @@ -250,7 +250,7 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]): self, mm_items: MultiModalDataItems, 
hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -343,7 +343,7 @@ class PixtralHFMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_config = self.info.get_hf_config() @@ -394,7 +394,7 @@ def _build_llava_or_pixtral_hf_processor( info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, PixtralHFProcessingInfo): return PixtralHFMultiModalProcessor( @@ -795,7 +795,6 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, ) -> MultiModalInputs: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -807,7 +806,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ) result = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - tokenization_kwargs, return_mm_hashes) + tokenization_kwargs) mm_items = self._to_mm_items(mm_data) mm_item_counts = mm_items.get_all_counts() @@ -829,26 +828,19 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): target=[image_token_id] * num_image_tokens, replacement=get_replacement_mantis, ) - ]) + ], mm_item_counts) prompt_ids, prompt, _ = self._apply_prompt_updates( result["prompt_token_ids"], mantis_mm_repls, - mm_item_counts, ) - unbound_orig_repls = self._get_prompt_updates( + orig_repls = self._get_mm_prompt_updates( mm_items, hf_processor_mm_kwargs, mm_kwargs, ) - 
orig_repls = self._bind_and_group_updates(unbound_orig_repls) - - mm_placeholders = self._find_mm_placeholders( - orig_repls, - prompt_ids, - mm_item_counts, - ) + mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) mm_placeholder_ranges = { diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index abc519edadcca..cf9852de633f3 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -16,7 +16,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, VideoEmbeddingItems, VideoProcessorItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -185,7 +185,7 @@ class LlavaNextVideoMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() video_token_id = hf_config.video_token_index diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index ecd24af030a14..e4ac0cd919101 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -3,7 +3,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import Final, Literal, Optional, Protocol, TypedDict, Union +from typing import Annotated, Final, Literal, Optional, Protocol, Union import torch import torch.nn as nn @@ -11,18 +11,18 @@ from transformers import (BatchFeature, 
LlavaOnevisionConfig, LlavaOnevisionProcessor) from transformers.models.llava_onevision.modeling_llava_onevision import ( get_anyres_image_grid_shape, unpad_image) -from typing_extensions import NotRequired from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, VideoEmbeddingItems, VideoProcessorItems) from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -38,44 +38,62 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, _MAX_FRAMES_PER_VIDEO = 16 -class LlavaOnevisionVideoPixelInputs(TypedDict): - type: Literal["pixel_values_videos"] - pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]] +class LlavaOnevisionVideoPixelInputs(TensorSchema): """ - Shape: `(batch_size * num_videos, num_frames, num_channels, height, width)` + Dimensions: + - bn: Batch size * number of videos + - f: Number of frames + - c: Number of channels (3) + - h: Height + - w: Width - Note that `num_videos` may be different for each batch, and 'num_frames' - may be different for each video, in which case the data is passed as a - list instead of a batched tensor. + Note that `num_videos` may be different for each batch, and 'num_frames' + may be different for each video, in which case the data is passed as a + list instead of a batched tensor. 
""" + type: Literal["pixel_values_videos"] = "pixel_values_videos" + + pixel_values_videos: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "f", 3, "h", "w", dynamic_dims={"f"}), + ] -class LlavaOnevisionImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: Union[torch.Tensor, list[torch.Tensor]] +class LlavaOnevisionImagePixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` + Dimensions: + - bn: Batch size * number of images + - np: Number of patches (1 + num_patches) + - c: Number of channels (3) + - h: Height + - w: Width - Note that `num_patches` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. + Note that `num_patches` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values"] = "pixel_values" - image_sizes: NotRequired[torch.Tensor] + pixel_values: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}), + ] + + image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] + + +class LlavaOnevisionImageEmbeddingInputs(TensorSchema): """ - Shape: `(batch_size * num_images, 2)` - - This should be in `(height, width)` format. + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" - -class LlavaOnevisionImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. 
- """ + data: Annotated[ + torch.Tensor, + TensorShape("bn", "ifs", "hs"), + ] LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs, @@ -372,7 +390,7 @@ class LlavaOnevisionMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: image_repls = super()._get_prompt_updates( mm_items=mm_items, @@ -482,44 +500,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.model.make_empty_intermediate_tensors) - def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: - expected_dims = (2, ) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - f"The expected shape of image sizes per image per batch " - f"is {expected_expr}. You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - - def _validate_image_pixel_values( - self, data: Union[torch.Tensor, list[torch.Tensor]] - ) -> Union[torch.Tensor, list[torch.Tensor]]: - - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("num_patches", *map(str, expected_dims)) - raise ValueError( - "The expected shape of pixel values per image per batch " - f"is {expected_expr}. 
You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -540,11 +520,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, return LlavaOnevisionImagePixelInputs( type="pixel_values", - pixel_values=self._validate_image_pixel_values( - flatten_bn(pixel_values)), - image_sizes=self._validate_image_sizes( - flatten_bn(image_sizes, concat=True)), - ) + pixel_values=flatten_bn(pixel_values), + image_sizes=flatten_bn(image_sizes, concat=True), + resolve_bindings={ + "h": self.config.vision_config.image_size, + "w": self.config.vision_config.image_size + }) if image_embeds is not None: if not isinstance(image_embeds, torch.Tensor): @@ -558,27 +539,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, raise AssertionError("This line should be unreachable.") - def _validate_video_pixel_values( - self, data: Union[torch.Tensor, list[torch.Tensor]] - ) -> Union[torch.Tensor, list[torch.Tensor]]: - - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape[2:]) - - if actual_dims != expected_dims: - expected_expr = ("num_frames", *map(str, expected_dims)) - raise ValueError( - "The expected shape of pixel values in each video frame " - f"is {expected_expr}. 
You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_video_input( self, **kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]: @@ -600,7 +560,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, return LlavaOnevisionVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=flatten_bn(pixel_values_videos), - ) + resolve_bindings={ + "h": self.config.vision_config.image_size, + "w": self.config.vision_config.image_size + }) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index f4aaf0c6f467c..f02499a4f96b5 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -9,6 +9,7 @@ from torch import nn from transformers import MambaConfig from vllm import envs +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm @@ -81,10 +82,12 @@ class MambaDecoderLayer(nn.Module): else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer(hidden_states, mamba_cache_params) - return hidden_states, residual + output = torch.empty_like(hidden_states) + self.mixer(hidden_states, output, mamba_cache_params) + return output, residual +@support_torch_compile class MambaModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/mimo_mtp.py b/vllm/model_executor/models/mimo_mtp.py index 19afc5be3fb87..5a2079bf5121a 100644 --- a/vllm/model_executor/models/mimo_mtp.py +++ b/vllm/model_executor/models/mimo_mtp.py @@ -164,15 +164,14 @@ class MiMoMTP(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - previous_hidden_states: 
torch.Tensor, + hidden_states: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: assert spec_step_idx == 0, "mimo_mtp only support predict one token now" - hidden_states = self.model(input_ids, positions, - previous_hidden_states, inputs_embeds, - spec_step_idx) + hidden_states = self.model(input_ids, positions, hidden_states, + inputs_embeds, spec_step_idx) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index e1746695bd5db..225668d87facb 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -24,7 +24,7 @@ # limitations under the License. """Inference-only MiniCPM-O model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Callable, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Callable, Literal, Optional, Union import torch from torch import nn @@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, NestedTensors) from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, @@ -49,6 +49,7 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, MultiModalDataParser) from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, PromptUpdateDetails) +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6, MiniCPMVDummyInputsBuilder, @@ -61,35 
+62,52 @@ from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn, CPU_DEVICE = torch.device("cpu") -class MiniCPMOAudioFeatureInputs(TypedDict): - type: Literal["audio_features"] - audio_features: Union[torch.Tensor, list[torch.Tensor]] +class MiniCPMOAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - bns: Batch size * number of audios * number of slices + - bn: Batch size * number of audios + - c: Number of channels + - l: Length + - s: Number of slices + """ + type: Literal["audio_features"] = "audio_features" + + audio_features: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bns", "c", "l", dynamic_dims={"l"}), + ] """ - Shape: `(batch_size * num_audios * num_slices, num_channels, length)` Slice here means chunk. Audio that is too long will be split into slices, - which is the same as image. - Padding is used therefore `audio_features` is `torch.Tensor`. + which is the same as image. Padding is used therefore `audio_features` is + `torch.Tensor`. """ - audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]] + audio_feature_lens: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "s"), + ] """ - Shape: `(batch_size * num_audios, num_slices)` - This should be feature length of each audio slice, which equals to `audio_features.shape[-1]` """ -class MiniCPMOAudioEmbeddingInputs(TypedDict): - type: Literal["audio_embeds"] - audio_embeds: Union[torch.Tensor, list[torch.Tensor]] +class MiniCPMOAudioEmbeddingInputs(TensorSchema): """ - Shape: `(batch_size * num_audios, num_slices, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. - instead of a batched tensor. + Dimensions: + - bn: Batch size * number of audios + - s: Number of slices + - h: Hidden size (must match language model backbone) + Length of each slice may vary, so pass it as a list. 
""" + type: Literal["audio_embeds"] = "audio_embeds" + + audio_embeds: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "s", "h", dynamic_dims={"s"}), + ] MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, @@ -316,7 +334,7 @@ class MiniCPMOMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: base_updates = super()._get_prompt_updates( mm_items=mm_items, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 47ce771d8c901..0181bfeebda08 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -27,12 +27,14 @@ import math from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence from functools import partial +from itertools import chain from typing import Annotated, Any, Callable, Literal, Optional, Union import numpy as np import torch import torch.types from torch import nn +from torch.nn.init import trunc_normal_ from transformers import BatchFeature, PretrainedConfig from typing_extensions import TypeVar @@ -47,10 +49,11 @@ from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM +from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageProcessorItems, ImageSize, ModalityData, 
ModalityDataItems, @@ -58,7 +61,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, VideoItem, VideoProcessorItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) + PromptUpdate, PromptUpdateDetails, + ResolvedPromptUpdate, _seq2text) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -217,6 +221,187 @@ class Resampler2_5(BaseResampler): return x +class Resampler4_5(Resampler2_5): + + def __init__(self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: tuple[int, int] = (70, 70), + max_temporal_size: int = 36000, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__(num_queries, + embed_dim, + num_heads, + kv_dim, + norm_layer, + max_size, + quant_config=quant_config, + prefix=prefix) + + trunc_normal_(self.query, std=.02) + self.max_temporal_size = max_temporal_size + self._set_temporal_pos_cache(self.max_temporal_size) + self.apply(self._init_weights) + + def get_1d_sincos_pos_embed_from_temporal_size(self, embed_dim: int, + pos: np.ndarray): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. 
/ 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + def _set_temporal_pos_cache(self, + max_temporal_size: int, + device: torch.types.Device = "cpu") -> None: + temporal_size = np.arange(max_temporal_size, dtype=np.float32) + pos_embed = torch.from_numpy( + self.get_1d_sincos_pos_embed_from_temporal_size( + self.embed_dim, temporal_size)).float().to(device) + self.register_buffer("temporal_pos_embed", pos_embed, persistent=False) + + def _adjust_temporal_pos_cache(self, + max_temporal_size: int, + device: torch.types.Device = "cpu"): + if max_temporal_size > self.max_temporal_size: + self.max_temporal_size = max_temporal_size + self._set_temporal_pos_cache(self.max_temporal_size, device) + + def _init_weights(self, m: Union[nn.Linear, nn.LayerNorm]): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, + x: torch.Tensor, + tgt_sizes: torch.Tensor, + # temporal_ids for high refresh rate videos + temporal_ids=None + ) -> torch.Tensor: + assert x.shape[0] == tgt_sizes.shape[0] + bs = x.shape[0] + + device = x.device + dtype = x.dtype + + patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] + + self._adjust_pos_cache(tgt_sizes, device=device) + + temporal_pos_emb = False + temporal_ids_flatten = None + if temporal_ids is not None: + # example: [[-1], [-1], [2, 6, 9]] + temporal_ids_flatten = list(chain.from_iterable(temporal_ids)) + max_temporal_size = max(temporal_ids_flatten, default=0) + if max_temporal_size > -1: + temporal_pos_emb = True + if max_temporal_size > self.max_temporal_size: + 
self._adjust_temporal_pos_cache(max_temporal_size, device) + + max_patch_len = patch_len.max().item() + assert isinstance(max_patch_len, int) + + key_padding_mask = torch.zeros((bs, max_patch_len), + dtype=torch.bool, + device=device) + + x, _ = self.kv_proj(x) # B * L * D + x = self.ln_kv(x).permute(1, 0, 2) # L * B * D + q = self.ln_q(self.query) # Q * D + + pos_embed_2d = [] + pos_embed_temporal = [] + for i in range(bs): + tgt_h, tgt_w = tgt_sizes[i] + if temporal_pos_emb: + if temporal_ids_flatten[i] == -1: + pos_embed_temporal.append( + torch.zeros(self.embed_dim, dtype=dtype, + device=device)) + else: + pos_embed_temporal.append(self.temporal_pos_embed[ + temporal_ids_flatten[i]].to(dtype)) # D + + pos_embed_2d.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape( + (tgt_h * tgt_w, -1)).to(dtype)) # patches * D + key_padding_mask[i, patch_len[i]:] = True + + pos_embed_2d = torch.nn.utils.rnn.pad_sequence( + pos_embed_2d, batch_first=True, + padding_value=0.0).permute(1, 0, 2) # BLD => L * B * D + + k = x + v = x + pos_embed_2d + if pos_embed_temporal: + k += torch.stack(pos_embed_temporal, dim=0) + bs = len(temporal_ids) + merge_k = [] + merge_v = [] + merge_key_padding_mask = [] + + start = 0 + for tp in temporal_ids: + end = start + len(tp) + # L * (end-start) * D -> (end-start) * L * D + # -> 1 * L*(end-start) * D + merge_k.append(k[:, start:end, :].permute(1, 0, 2).reshape( + -1, self.embed_dim)) + merge_v.append(v[:, start:end, :].permute(1, 0, 2).reshape( + -1, self.embed_dim)) + merge_key_padding_mask.append( + key_padding_mask[start:end, :].reshape(-1, 1)) + + start = end + + k = torch.nn.utils.rnn.pad_sequence(merge_k, + batch_first=True, + padding_value=0.0).permute( + 1, 0, 2) # L*(end-start) + v = torch.nn.utils.rnn.pad_sequence(merge_v, + batch_first=True, + padding_value=0.0).permute( + 1, 0, 2) # L*(end-start) + key_padding_mask = torch.nn.utils.rnn.pad_sequence( + merge_key_padding_mask, batch_first=True, + padding_value=True).squeeze(-1) + + out 
= self.attn( + self._repeat(q, bs), # Q * B * D + k, # L * B * D + L * B * D + v, + key_padding_mask=key_padding_mask, + )[0] + # out: Q * B * D + x = out.permute(1, 0, 2) # B * Q * D + + x = self.ln_post(x) + x = x @ self.proj + return x + + def get_version_by_config(config: PretrainedConfig) -> tuple[int, ...]: version_float = getattr(config, "version", None) @@ -353,9 +538,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: mm_limits = {"image": None} - if self.get_model_version() == (2, - 6) or self.get_model_version() == (4, - 0): + if self.get_model_version() in {(2, 6), (4, 0), (4, 5)}: mm_limits["video"] = None return mm_limits @@ -636,8 +819,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): out_keys: set[str], ) -> dict[str, NestedTensors]: # This processor supports zipping prompt and mm_data together - if self.info.get_model_version() == ( - 2, 6) or self.info.get_model_version() == (4, 0): + if self.info.get_model_version() in {(2, 6), (4, 0), (4, 5)}: inputs = super()._call_hf_processor( prompt=prompts, # type: ignore mm_data=mm_data, @@ -694,7 +876,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: placeholders = [("image", self.info.image_pattern), ("video", self.info.video_pattern)] @@ -744,6 +926,43 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): for modality, pattern in placeholders ] + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + new_item_idx, + ) + + if cached_update.modality == "image": + tokenizer = self.info.get_tokenizer() + image_processor = self.info.get_image_processor() + 
version = self.info.get_model_version() + + text = _seq2text(tokenizer, cached_update.content.full) + prev_item_idx = cached_update.item_idx + + if version == (2, 0) or version == (2, 5): + im_start = image_processor.im_start_token + im_end = image_processor.im_end_token + else: + im_start = image_processor.im_id_start + im_end = image_processor.im_id_end + + new_update = new_update.with_content( + PromptUpdateDetails.select_text( + text.replace( + f"{im_start}{prev_item_idx}{im_end}", + f"{im_start}{new_item_idx}{im_end}", + 1, + ), + "<unk>", + )) + + return new_update + def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -1302,6 +1521,8 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): ], } + supports_encoder_tp_data = True + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) assert self.version == (4, 0) @@ -1395,11 +1616,121 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): return loader.load_weights(weights) +class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + assert self.version == (4, 5) + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)): + return None + return quant_config + + def init_llm( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> nn.Module: + return Qwen3ForCausalLM(vllm_config=vllm_config, prefix=prefix) + + def init_vision_module( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + quant_config = self._maybe_ignore_quant_config(quant_config) + model = Idefics2VisionTransformer(config.vision_config, + 
quant_config=quant_config, + prefix=prefix) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + return model + + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + quant_config = self._maybe_ignore_quant_config(quant_config) + with set_default_torch_dtype(torch.float16): + # The resampler in 4.5 extends the 2.5 one (Resampler2_5) with temporal position embeddings. + resampler = Resampler4_5(num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix) + + return resampler.to(device=current_platform.device_type, + dtype=torch.get_default_dtype()) + + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + tgt_sizes = data["tgt_sizes"] + temporal_ids = data.get('temporal_ids', None) + + B = len(pixel_values) + P = pixel_values[0].shape[-2] + L = max(item.shape[-1] for item in pixel_values) + device = pixel_values[0].device + dtype = pixel_values[0].dtype + + all_pixel_values = torch.zeros((B, 3, P, L), + dtype=dtype, + device=device) + all_temporal_ids = None if temporal_ids is None else flatten_2d_lists( + temporal_ids) + for i, pixel_values_item in enumerate(pixel_values): + L_item = pixel_values_item.shape[-1] + all_pixel_values[i, ..., :L_item] = pixel_values_item + + num_patches = tgt_sizes.prod(-1) + max_patches = num_patches.max().item() + assert isinstance(max_patches, int) + + patch_attn_mask = torch.zeros((B, max_patches), + dtype=torch.bool, + device=device) + for i, num_patches_item in enumerate(num_patches): + patch_attn_mask[i, :num_patches_item] = True + + vision_embedding = self.vpm( + all_pixel_values, + patch_attention_mask=patch_attn_mask.unsqueeze(1), + tgt_sizes=tgt_sizes, + ) + + return self.resampler(vision_embedding, tgt_sizes, all_temporal_ids) + + def
load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, + skip_prefixes=["apm.", "audio", "tts"]) + return loader.load_weights(weights) + + _SUPPORT_VERSION = { (2, 0): MiniCPMV2_0, (2, 5): MiniCPMV2_5, (2, 6): MiniCPMV2_6, (4, 0): MiniCPMV4_0, + (4, 5): MiniCPMV4_5, } diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 82e96844cd5f6..0e854bd7d913d 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -4,7 +4,10 @@ import copy import math from collections.abc import Iterable -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import regex as re import torch @@ -339,6 +342,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): def mamba_type(self) -> str: return "linear_attention" + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.linear_attn import ( + LinearAttentionBackend) + return LinearAttentionBackend + def get_state_dtype(self) -> tuple[torch.dtype]: return MambaStateDtypeCalculator.linear_attention_state_dtype( self.model_config.dtype, diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index 8107c6e8a04a1..cc7db849a28bf 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping -from typing import Literal, Optional, TypedDict, Union, cast +from typing import Annotated, Literal, Optional, Union, cast import torch import torch.nn as nn from transformers import BatchFeature, PretrainedConfig +from transformers.models.llava_next.modeling_llava_next 
import ( + get_anyres_image_grid_shape, unpad_image) from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn @@ -17,6 +19,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.sequence import IntermediateTensors from vllm.utils.jsontree import json_map_leaves +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -29,24 +32,36 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -class MiniMaxVL01ImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor +class MiniMaxVL01ImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, num_channels, height, width)` + Dimensions: + - bn: Batch size * number of images + - np: Number of patches + 1 + - c: Number of channels (3) + - h: Height + - w: Width - Note that `height` or `width` may be different per batch and image, + Note that `num_patches` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np", "h", "w"})] + + image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] + # This should be in `(height, width)` format. -class MiniMaxVL01ImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. 
+class MiniMaxVL01ImageEmbeddingInputs(TensorSchema): """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] MiniMaxVL01ImageInputs = Union[MiniMaxVL01ImagePixelInputs, @@ -141,6 +156,7 @@ class MiniMaxVL01MultiModalProcessor( ) -> Mapping[str, MultiModalFieldConfig]: return { "pixel_values": MultiModalFieldConfig.batched("image"), + "image_sizes": MultiModalFieldConfig.batched("image"), "image_embeds": MultiModalFieldConfig.batched("image"), } @@ -239,7 +255,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower - image_features = vision_tower(pixel_values) + image_features = tuple(vision_tower(p) for p in pixel_values) def select_features(leaf: torch.Tensor): return self._select_image_features( @@ -252,6 +268,56 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, json_map_leaves(select_features, image_features), ) + # adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631 + def pack_image_features(self, image_features: list[torch.Tensor], + image_sizes: torch.Tensor): + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = (self.config.vision_config.image_size // + self.config.vision_config.patch_size) + if height * width != base_image_feature.shape[0]: + raise ValueError( + "The number of patches is not consistent with " + "the image size.") + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx], + 
self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + + image_feature = image_feature.view(num_patch_height, + num_patch_width, height, + width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, + 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, + image_sizes[image_idx]) + + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1).to( + image_feature.dtype), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), + dim=0) + else: + image_feature = image_feature[0] + image_feature = torch.cat( + (image_feature, + self.image_newline[None].to(image_feature)), + dim=0) + new_image_features.append(image_feature) + return new_image_features + def _process_image_pixels( self, inputs: MiniMaxVL01ImagePixelInputs, @@ -259,7 +325,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, assert self.vision_tower is not None pixel_values = inputs["pixel_values"] - return self._image_pixels_to_features(self.vision_tower, pixel_values) def _process_image_input( @@ -281,38 +346,31 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, image_embeds = self.multi_modal_projector(torch.cat(image_features)) image_embeds = torch.split(image_embeds, feature_sizes) - return image_embeds - - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. 
" - f"You supplied {tuple(data.shape)}.") - - return data + image_sizes = image_input.get("image_sizes") + return self.pack_image_features(image_embeds, image_sizes) def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[MiniMaxVL01ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) + image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) if pixel_values is None and image_embeds is None: return None - if pixel_values is not None: + if pixel_values is not None and image_sizes is not None: if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + if not isinstance(image_sizes, (torch.Tensor, list)): + raise ValueError("Incorrect type of image sizes. " + f"Got type: {type(image_sizes)}") + return MiniMaxVL01ImagePixelInputs( type="pixel_values", - pixel_values=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), + pixel_values=flatten_bn(pixel_values), + image_sizes=flatten_bn(image_sizes, concat=True), ) if image_embeds is not None: diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 9e29a96c6e44a..08948960b275c 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -3,7 +3,7 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, +from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, Union) import torch @@ -22,16 +22,17 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import 
(MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) @@ -42,15 +43,23 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .vision import get_vision_encoder_info -class Mistral3ImagePixelInputs(TypedDict): - type: Literal["pixel_values_pixtral"] - pixel_values: Union[torch.Tensor, list[torch.Tensor]] +class Mistral3ImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image """ - Shape: `(batch_size * num_images, num_channels, height, width)` - Note that `height` or `width` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. - """ + type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral" + + # Note that `height` or `width` may be different per batch and image, + # in which case the data is passed as a list instead of a batched tensor. 
+ pixel_values: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}), + ] class Mistral3PatchMerger(nn.Module): @@ -265,7 +274,7 @@ class Mistral3MultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_config = self.info.get_hf_config() @@ -313,7 +322,7 @@ def _build_mistral3_processor( info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: assert isinstance(info, Mistral3ProcessingInfo) return Mistral3MultiModalProcessor( @@ -456,19 +465,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. 
" - f"You supplied {tuple(data.shape)}.") - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[Mistral3ImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 30ae3f26c8e45..2a60450de4141 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -17,7 +17,7 @@ """PyTorch Mllama model.""" import math from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import numpy as np import torch @@ -56,13 +56,15 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalFieldConfig, MultiModalKwargs) + MultiModalFieldConfig, + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPMLP from .interfaces import SupportsMultiModal, SupportsV0Only @@ -72,15 +74,30 @@ from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix logger = init_logger(__name__) -class MllamaImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: torch.Tensor - """Shape: """ - """(batch_size, max_num_image, max_num_chunk, num_channel, height, width)""" - aspect_ratio_ids: torch.Tensor - """Shape: `(batch_size, max_num_image)`""" - aspect_ratio_mask: torch.Tensor - """Shape: `(batch_size, max_num_image, max_num_tiles)`""" +class 
MllamaImagePixelInputs(TensorSchema): + """ + Dimensions: + - batch_size: Batch size + - max_num_image: Max number of images + - max_num_chunk: Max number of chunks + - max_num_tiles: Max number of tiles per image + - num_channel: Number of channels + - height: Height + - width: Width + """ + + type: Literal["pixel_values"] = "pixel_values" + + data: Annotated[torch.Tensor, + TensorShape("batch_size", "max_num_image", "max_num_chunk", + "num_channel", "height", "width")] + + aspect_ratio_ids: Annotated[torch.Tensor, + TensorShape("batch_size", "max_num_image")] + + aspect_ratio_mask: Annotated[ + torch.Tensor, + TensorShape("batch_size", "max_num_image", "max_num_tiles")] # TODO: support LlamaImageEmbeddingInputs @@ -167,10 +184,9 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, ) -> MultiModalEncDecInputs: mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - tokenization_kwargs, return_mm_hashes) + tokenization_kwargs) image_token_id = self.info.get_hf_config().image_token_index # Check that the number of image tokens in the decoder prompt matches @@ -217,7 +233,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] # Set encoder prompt length based on the number of tiles. # This tells the block manager to allocate correct number # of slots for encoder tokens. 
- num_tiles = mm_inputs["mm_kwargs"]["num_tiles"] + num_tiles = mm_inputs["mm_kwargs"].get_data()["num_tiles"] decode_tiles = num_tiles[num_encode_images:num_images].sum().item() num_tokens = decode_tiles * token_per_chunk mm_inputs["encoder_prompt_token_ids"] = [image_token_id @@ -302,7 +318,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: token_per_chunk = self.info.get_token_per_chunk_from_config() image_token_id = self.info.get_hf_config().image_token_index diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index b405dfca6d39b..ac9b968f7a0cd 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -44,7 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -646,13 +646,8 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> list[PromptUpdate]: - assert ( - mm_items.get_count("image", strict=False) == 0 - or "aspect_ratios" in out_mm_kwargs - ), "Transformers expect to include aspect_ratios in out_mm_kwargs" - config = self.info.get_hf_config() vision_config = config.vision_config @@ -662,7 +657,8 @@ class 
Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] img_patch_token = hf_processor.img_patch_token def get_replacement(item_idx: int): - aspect_ratio = out_mm_kwargs["aspect_ratios"][item_idx] + out_item = out_mm_kwargs["image"][item_idx] + aspect_ratio = out_item["aspect_ratios"].data repl = hf_processor._prompt_split_image( aspect_ratio=aspect_ratio, @@ -720,6 +716,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, "gate_up_proj": ["gate_proj", "up_proj"], } + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -732,8 +730,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config - self.use_data_parallel = (vllm_config.parallel_config. - enable_multimodal_encoder_data_parallel) + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + self.config = config self.quant_config = quant_config self.multimodal_config = multimodal_config diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index c6e84e2d4e040..4778555861286 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -7,7 +7,7 @@ import torch from torch import nn from transformers import ModernBertConfig -from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -26,7 +26,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask -from .interfaces import 
SupportsCrossEncoding, default_pooling_type +from .interfaces import SupportsCrossEncoding +from .interfaces_base import default_pooling_type from .utils import WeightsMapper, maybe_prefix @@ -104,12 +105,12 @@ class ModernBertAttention(nn.Module): head_size=self.head_dim, dim=self.head_dim, base=rope_theta) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - prefix=f"{layer_id}.attn", - attn_type=AttentionType.ENCODER_ONLY, - per_layer_sliding_window=sliding_window) + self.attn = EncoderOnlyAttention( + self.num_heads, + self.head_dim, + self.scaling, + prefix=f"{layer_id}.attn", + per_layer_sliding_window=sliding_window) self.Wo = RowParallelLinear(config.hidden_size, config.hidden_size, bias=config.attention_bias) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 78dc0dca957f0..5fc28ed0e493e 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -5,7 +5,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass from functools import cached_property, partial -from typing import Optional, TypedDict, Union +from typing import Annotated, Optional, Union import numpy as np import torch @@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -51,6 +51,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import 
TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant) @@ -70,23 +71,25 @@ IM_END_TOKEN = "<im_end>" POOLING_SIZE = 2 -class MolmoImageInputs(TypedDict): - images: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size * num_images, num_crops, num_patch, patch_dim)`""" - - image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]] - """Shape: `(batch_size * num_images, num_crops, num_patch)`""" - - feat_is_patch: Union[torch.Tensor, list[torch.Tensor]] +class MolmoImageInputs(TensorSchema): """ - A boolean mask indicating which image features correspond - to patch tokens. - - Shape: `(batch_size * num_images, num_crops, num_patch)` + Dimensions: + - bn: Batch size * number of images + - nc: Number of crops + - np: Number of patches + - pd: Patch dimension """ + images: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "nc", "np", "pd")] - num_crops: torch.Tensor - """Shape: `(batch_size * num_images)`""" + image_masks: Annotated[Optional[Union[torch.Tensor, list[torch.Tensor]]], + TensorShape("bn", "nc", "np")] + + feat_is_patch: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "nc", "np")] + # A boolean mask indicating which image features correspond to patch tokens. 
+ + num_crops: Annotated[torch.Tensor, TensorShape("bn")] @dataclass @@ -1282,7 +1285,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -1410,28 +1413,17 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, **kwargs: object, ) -> Optional[MolmoImageInputs]: images = kwargs.pop("images", None) + image_masks = kwargs.pop("image_masks", None) + feat_is_patch = kwargs.pop("feat_is_patch", None) + num_crops = kwargs.pop("num_crops", None) + if images is None: return None - if not isinstance(images, (torch.Tensor, list)): - raise ValueError("Incorrect type of images. " - f"Got type: {type(images)}") - - image_masks = kwargs.pop("image_masks", None) - if not (image_masks is None or isinstance(image_masks, - (torch.Tensor, list))): - raise ValueError("Incorrect type of image_masks. " - f"Got type: {type(image_masks)}") - - feat_is_patch = kwargs.pop("feat_is_patch", None) - if not isinstance(feat_is_patch, (torch.Tensor, list)): - raise ValueError("Incorrect type of feat_is_patch. " - f"Got type: {type(feat_is_patch)}") - - num_crops = kwargs.pop("num_crops", None) if not isinstance(num_crops, (torch.Tensor, list)): raise ValueError("Incorrect type of num_crops. 
" f"Got type: {type(num_crops)}") + num_crops = flatten_bn(num_crops, concat=True) img_patch_id = kwargs.pop("img_patch_id", None) if not isinstance(img_patch_id, torch.Tensor): @@ -1439,8 +1431,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, f"Got type: {type(img_patch_id)}") self.img_patch_id = img_patch_id.flatten().unique().item() - num_crops = flatten_bn(num_crops, concat=True) - return MolmoImageInputs( images=images, image_masks=image_masks, diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index 82bcd064624f3..a9c7d8044e10c 100644 --- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -458,27 +458,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, vit_embeds = self.mlp1(vit_embeds) return vit_embeds - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - - #use force_image_size to get image_size - h = w = self.config.force_image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - "The expected shape of pixel values per image per batch " - f" per patch is {expected_expr}. 
" - f"You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[InternVLImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) @@ -516,9 +495,12 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, return InternVLImagePixelInputs( type="pixel_values", - pixel_values_flat=self._validate_pixel_values( - pixel_values_flat), + pixel_values_flat=pixel_values_flat, num_patches=image_num_patches, + resolve_bindings={ + "h": self.config.force_image_size, + "w": self.config.force_image_size + }, ) raise AssertionError("This line should be unreachable.") diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index 4bea1392a6814..3bbf4c67604c7 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -16,7 +16,7 @@ from transformers import PretrainedConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs +from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, @@ -106,18 +106,19 @@ class NVLMMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if "image_num_patches" in out_mm_kwargs: - image_num_patches = out_mm_kwargs["image_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "image_num_patches" in out_mm_data: + image_num_patches = out_mm_data["image_num_patches"] assert isinstance(image_num_patches, 
torch.Tensor) image_num_patches = image_num_patches.tolist() - elif "image_embeds" in out_mm_kwargs: + elif "image_embeds" in out_mm_data: # TODO: Use image size information in dictionary embedding inputs # to compute num_patches (similar to Qwen2-VL) - image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + image_num_patches = [None] * len(out_mm_data["image_embeds"]) else: image_num_patches = [] diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 1dc4df85c1bc4..01639d398126f 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -47,7 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -91,6 +91,7 @@ class OlmoAttention(nn.Module): self.total_num_heads, bias=config.attention_bias, quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", ) # Rotary embeddings. @@ -114,6 +115,7 @@ class OlmoAttention(nn.Module): self.hidden_size, bias=config.attention_bias, quant_config=quant_config, + prefix=f"{prefix}.o_proj", ) def forward( @@ -142,6 +144,7 @@ class OlmoMLP(nn.Module): self, config: OlmoConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -154,6 +157,7 @@ class OlmoMLP(nn.Module): [self.intermediate_size] * 2, bias=False, quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", ) # Activation function. 
@@ -165,6 +169,7 @@ class OlmoMLP(nn.Module): self.hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.down_proj", ) def forward( @@ -197,7 +202,7 @@ class OlmoDecoderLayer(nn.Module): prefix=f"{prefix}.self_attn") # MLP block. - self.mlp = OlmoMLP(config, quant_config) + self.mlp = OlmoMLP(config, quant_config, prefix=f"{prefix}.mlp") # LayerNorm self.input_layernorm = nn.LayerNorm(config.hidden_size, @@ -326,10 +331,21 @@ class OlmoModel(nn.Module): return loaded_params -class OlmoForCausalLM(nn.Module, SupportsPP): +class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA): """ Extremely barebones HF model wrapper. """ + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 499e6d30ed6b0..66a0f9115585a 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -33,6 +33,7 @@ from torch import nn from transformers import Olmo2Config from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed.communication_op import tensor_model_parallel_all_gather @@ -48,7 +49,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsPP +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP from vllm.model_executor.models.utils import ( AutoWeightsLoader, is_pp_missing_parameter, 
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -253,6 +254,7 @@ class Olmo2DecoderLayer(nn.Module): return hidden_states +@support_torch_compile class Olmo2Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -354,10 +356,21 @@ class Olmo2Model(nn.Module): return loaded_params -class Olmo2ForCausalLM(nn.Module, SupportsPP): +class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): """ Extremely barebones HF model wrapper. """ + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py index 6b27980e0b0c3..5b3ad7cbd07ad 100644 --- a/vllm/model_executor/models/ovis.py +++ b/vllm/model_executor/models/ovis.py @@ -42,7 +42,7 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement) @@ -375,11 +375,12 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> list[PromptReplacement]: - def get_replacement_ovis(item_idx): - grid = out_mm_kwargs["grids"][item_idx] + def get_replacement_ovis(item_idx: int): + out_item = out_mm_kwargs["image"][item_idx] + grid = out_item["grids"].data hf_processor = self.info.get_hf_processor() return 
hf_processor.construct_image_placeholders(grid) diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py new file mode 100644 index 0000000000000..58a14072443cb --- /dev/null +++ b/vllm/model_executor/models/ovis2_5.py @@ -0,0 +1,574 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" PyTorch Ovis model.""" +from collections.abc import Iterable, Mapping +from functools import partial +from typing import Optional, Union + +import torch +import torch.nn as nn +from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig + +from vllm.config import VllmConfig +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.models.ovis import (OvisImagePatchInputs, + VisualEmbedding) +from vllm.model_executor.models.siglip2navit import Siglip2NavitModel +from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, + init_vllm_registered_model, + maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP + +IMAGE_TOKEN = "<image>" +VIDEO_TOKEN = "<video>" +INDICATOR_IDS = [-301, -302, -303, -304] + +IMAGE_PAD_TOKEN_MAP = { + "gemma2": "<unused0>", + "llama": "<|reserved_special_token_0|>", + "qwen2": "<|image_pad|>", + "qwen3": 
"<|image_pad|>", +} +IMAGE_PAD_TOKEN_ID_MAP = { + "gemma2": 7, + "llama": 128002, + "qwen2": 151655, + "qwen3": 151655, +} + + +def _ovis2_5_field_config(): + return dict(pixel_values=MultiModalFieldConfig.batched("image"), + grids=MultiModalFieldConfig.batched("image"), + indicator_tokens=MultiModalFieldConfig.batched("image"), + video_pixel_values=MultiModalFieldConfig.batched("video"), + video_indicator_tokens=MultiModalFieldConfig.batched("video"), + video_grids=MultiModalFieldConfig.batched("video")) + + +class VisualTokenizer(torch.nn.Module): + """ + VIT + """ + + def __init__( + self, + config: PretrainedConfig, + visual_vocab_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + self.vit = self._init_backbone( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.vit", + use_data_parallel=use_data_parallel, + ) + # reserved tokens for INDICATOR_IDS + head_dim = visual_vocab_size - len(INDICATOR_IDS) + self.head = torch.nn.Sequential( + ReplicatedLinear( + self.config.hidden_size * self.config.hidden_stride**2, + head_dim, + bias=False, + return_bias=False, + ), torch.nn.LayerNorm(head_dim)) + + def _init_backbone( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + model_type = config.model_type + if model_type == "siglip2_navit": + return Siglip2NavitModel(config=config, + quant_config=quant_config, + prefix=prefix, + use_data_parallel=use_data_parallel) + raise ValueError( + f"Unsupported visual tokenizer model_type: {model_type}") + + @property + def dtype(self) -> torch.dtype: + return next(self.head.parameters()).dtype + + @property + def device(self) -> torch.device: + return next(self.head.parameters()).device + + def tokenize(self, logits: torch.Tensor) -> torch.Tensor: + tokens = torch.softmax(logits, dim=-1, + 
dtype=torch.float32).to(logits.dtype) + return tokens + + def encode(self, pixel_values: torch.Tensor, + grid_thws: torch.Tensor) -> torch.Tensor: + features = self.vit(pixel_values, grid_thws) + # refer to qwen2.5-vl patchmerger + seq_len, _ = features.shape + features = features.reshape(seq_len // (self.config.hidden_stride**2), + -1) + + return features + + def forward(self, pixel_values: torch.Tensor, + grid_thws: torch.Tensor) -> torch.Tensor: + features = self.encode(pixel_values, grid_thws) + logits = self.head(features) + tokens = self.tokenize(logits) + # tokens' shape is [#Token, VocabSize-4], + # so padding with [#Token, 4], after which, + # tokens' shape should become [#Token, VocabSize]; + tokens = torch.nn.functional.pad( + tokens, + (0, len(INDICATOR_IDS)), + mode="constant", + value=0, + ) + return tokens + + +class Ovis2_5ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self, **kwargs): + vit_config = self.get_hf_config().vit_config + return self.ctx.get_hf_processor( + Ovis2_5Processor, + image_pad_token=self.get_image_pad_token(), + patch_size=vit_config.patch_size, + hidden_stride=vit_config.hidden_stride, + temporal_patch_size=vit_config.temporal_patch_size, + ) + + def get_image_pad_token(self) -> str: + hf_text_config = self.get_hf_config().get_text_config() + text_model_type = hf_text_config.model_type + return IMAGE_PAD_TOKEN_MAP.get(text_model_type) + + def get_image_processor(self) -> BaseImageProcessor: + return self.get_hf_processor().image_processor # type: ignore + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": 1} + + def get_image_size_with_most_features(self) -> ImageSize: + # NOTE(myselvess): max_pixels 1792 * 1792 hardcoded in original code + # TODO(myselvess): Be adjusted based on the max_pixels + return ImageSize(width=1792, height=1792) + + def get_num_image_tokens( + self, + *, + image_width: int, + 
image_height: int, + num_frames: int = 1, + ) -> tuple[ImageSize, int]: + hf_config = self.get_hf_config() + vit_config = hf_config.vit_config + patch_size = vit_config.patch_size + temporal_patch_size = vit_config.temporal_patch_size + # NOTE: Frames are padded to be divisible by `temporal_patch_size` + # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294 + padded_num_frames = num_frames + (-num_frames % temporal_patch_size) + grid_t = max(padded_num_frames // temporal_patch_size, 1) + grid_h = image_height // patch_size + grid_w = image_width // patch_size + num_patches = grid_t * grid_h * grid_w + num_vision_tokens = num_patches + return num_vision_tokens + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + return self.get_num_image_tokens(image_width=target_width, + image_height=target_height) + + def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() + num_frames = 0 + while True: + next_num_frames = num_frames + 1 + next_max_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + image_processor=None, + ) + if next_max_tokens > max_tokens: + break + num_frames = next_num_frames + return num_frames + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = self._get_max_video_frames(seq_len - + max_image_tokens) + max_frames_per_video = max_total_frames // max(max_videos, 1) + return max(max_frames_per_video, 1) + + def get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + image_processor: Optional[BaseImageProcessor], + ) -> 
int: + num_video_tokens = self.get_num_image_tokens(image_width=image_width, + image_height=image_height, + num_frames=num_frames) + return num_video_tokens + + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + target_width, target_height = self.get_image_size_with_most_features() + return self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_num_frames_with_most_features( + seq_len, mm_counts), + image_processor=None, + ) + + +class Ovis2_5DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2_5ProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + return IMAGE_TOKEN * num_images + VIDEO_TOKEN * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos, + ) + } + return mm_data + + +class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo] + ): + + def visual_indicators_to_visual_tokens( + self, + visual_indicators: list[int], + ) -> list[int]: + """ + Filter image indicators placeholders and convert them to corresponding + tokens in visual tokenizer. 
+ """ + hf_config = self.info.get_hf_config() + vte_vocab_size = hf_config.visual_vocab_size + return [ + vte_vocab_size - len(INDICATOR_IDS) + abs(x + 300) - 1 + for x in visual_indicators if x < -300 + ] + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + # Avoid warning from HF logger for text-only input + tokenizer = self.info.get_tokenizer() + prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + hf_processor = self.info.get_hf_processor() + + if "videos" in mm_data: + visual_indicators = [ + hf_processor.construct_visual_indicators((1, 1, 1), True) + for grid in processed_outputs["video_grids"] + ] + indicator_tokens = [ + self.visual_indicators_to_visual_tokens(indicator) + for indicator in visual_indicators + ] + processed_outputs["video_indicator_tokens"] = indicator_tokens + if "images" in mm_data: + visual_indicators = [ + hf_processor.construct_visual_indicators((1, 1, 1), False) + for grid in processed_outputs["grids"] + ] + indicator_tokens = [ + self.visual_indicators_to_visual_tokens(indicator) + for indicator in visual_indicators + ] + + processed_outputs["indicator_tokens"] = indicator_tokens + return processed_outputs + + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + + return prompt_tokens + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _ovis2_5_field_config() + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> 
list[PromptReplacement]: + + def get_replacement_ovis(item_idx, modality: str): + if modality == "image": + out_item = out_mm_kwargs["image"][item_idx] + grid = out_item["grids"].data + elif modality == "video": + out_item = out_mm_kwargs["video"][item_idx] + grid = out_item["video_grids"].data + hf_processor = self.info.get_hf_processor() + return hf_processor.construct_visual_placeholders(grid[0], ) + + return [ + PromptReplacement( + modality=modality, + target=IMAGE_TOKEN if modality == "image" else VIDEO_TOKEN, + replacement=partial(get_replacement_ovis, modality=modality), + ) for modality in ("image", "video") + ] + + +@MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor, + info=Ovis2_5ProcessingInfo, + dummy_inputs=Ovis2_5DummyInputsBuilder) +class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.config: PretrainedConfig = config + self.llm = init_vllm_registered_model( + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=maybe_prefix(prefix, "llm"), + ) + + self.visual_tokenizer = VisualTokenizer( + config=config.vit_config, + visual_vocab_size=config.visual_vocab_size, + quant_config=quant_config, + prefix=f"{prefix}.visual_tokenizer", + ) + + self.vte = VisualEmbedding(config.visual_vocab_size, + config.hidden_size) + + text_model_type = self.config.get_text_config().model_type + self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type] + + self.make_empty_intermediate_tensors = ( + self.get_language_model().make_empty_intermediate_tensors) + + def _parse_and_validate_visual_input( + self, is_video, + **kwargs: object) -> Optional[OvisImagePatchInputs]: + if is_video: + pixel_values = kwargs.pop("video_pixel_values", None) + indicator_tokens = kwargs.pop("video_indicator_tokens", None) + grids = kwargs.pop("video_grids", None) 
+ else: + pixel_values = kwargs.pop("pixel_values", None) + indicator_tokens = kwargs.pop("indicator_tokens", None) + grids = kwargs.pop("grids", None) + if pixel_values is None and indicator_tokens is None: + return None + + if pixel_values is not None and indicator_tokens is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(indicator_tokens, (torch.Tensor, list)): + raise ValueError("Incorrect type of indicator_tokens. " + f"Got type: {type(indicator_tokens)}") + + return OvisImagePatchInputs( + type="image_patches", + flat_data=flatten_bn(flatten_bn(pixel_values), concat=True), + patches_per_image=[ + x.shape[0] // (self.config.vit_config.hidden_stride**2) + for x in flatten_bn(pixel_values) + ], + indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), + concat=True), + grids=flatten_bn(flatten_bn(grids), concat=True), + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings: + image_patches_flat = image_input["flat_data"] + patches_per_image = image_input["patches_per_image"] + indicator_tokens = image_input["indicator_tokens"] + grid_thws = image_input["grids"] + + indicator_per_image = list( + map(lambda x: 2 if x > 1 else x + 2, patches_per_image)) + + target_dtype = self.visual_tokenizer.dtype + visual_tokens = self.visual_tokenizer( + image_patches_flat.to(target_dtype), grid_thws) + + visual_embeds = self.vte(visual_tokens) # 1:1 numeric eq. 
+ indicator_embeds = self.vte(indicator_tokens) + + visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0) + indicator_embeds_per_image = indicator_embeds.split( + indicator_per_image) + + vision_embeddings = [] + for indicator, visual in zip(indicator_embeds_per_image, + visual_embeds_per_image): + vision_embeddings_per_image = [] + visual = visual.unsqueeze(0) + for i in range(visual.shape[0]): + vision_embeddings_per_image.append( + torch.cat([indicator[i:i + 1], visual[i]], dim=0)) + vision_embeddings_per_image.append(indicator[i + 1:]) + vision_embeddings.append( + torch.cat(vision_embeddings_per_image, dim=0)) + return tuple(vision_embeddings) + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + embeddings = [] + + # NOTE: _parse_and_validate_visual_input has side-effects and pops + # keys from kwargs. We process images first, then videos. + image_input = self._parse_and_validate_visual_input(False, **kwargs) + if image_input: + embeddings.extend(self._process_image_input(image_input)) + + video_input = self._parse_and_validate_visual_input(True, **kwargs) + if video_input: + embeddings.extend(self._process_image_input(video_input)) + + return tuple(embeddings) if embeddings else None + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.llm.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + tmp = torch.concat(multimodal_embeddings, dim=0) + inputs_embeds[input_ids == self.image_pad_token_id] = tmp + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, 
inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + # up until here we have a inputs_embeds 100% numerical identity + # between the OG HF Transformers implementation and ours + hidden_states = self.llm( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.llm.compute_logits(hidden_states, sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def get_language_model(self) -> torch.nn.Module: + return self.llm diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index b1f2e53b0c712..95abb190e0a46 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch from torch import nn @@ -12,7 +12,7 @@ from vllm.logger import init_logger from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs) + MultiModalInputs, MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, 
MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -21,6 +21,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel @@ -32,19 +33,27 @@ from .vision import get_vision_encoder_info logger = init_logger(__name__) -class PaliGemmaImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: torch.Tensor - """Shape: `(batch_size * num_images, num_channels, height, width)`""" - - -class PaliGemmaImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. +class PaliGemmaImagePixelInputs(TensorSchema): """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width + """ + type: Literal["pixel_values"] = "pixel_values" + data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] + + +class PaliGemmaImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs, @@ -146,7 +155,7 @@ class PaliGemmaMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -194,10 +203,9 @@ class 
PaliGemmaMultiModalProcessor( mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, ) -> MultiModalInputs: mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - tokenization_kwargs, return_mm_hashes) + tokenization_kwargs) prompt_token_ids = mm_inputs["prompt_token_ids"] tokenizer = self.info.get_tokenizer() @@ -280,19 +288,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[PaliGemmaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -302,22 +297,17 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, return None if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - pixel_values = flatten_bn(pixel_values, concat=True) - return PaliGemmaImagePixelInputs( - type="pixel_values", - data=self._validate_pixel_values(pixel_values), - ) + h = w = self.config.vision_config.image_size + return PaliGemmaImagePixelInputs(type="pixel_values", + data=pixel_values, + resolve_bindings={ + "h": h, + "w": w + }) if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. 
" - f"Got type: {type(image_embeds)}") - image_embeds = flatten_bn(image_embeds, concat=True) return PaliGemmaImageEmbeddingInputs( diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 9ef4f8371eb3d..4522c7043d01a 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -32,15 +32,17 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) # yapf conflicts with isort for this block # yapf: disable from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, BoundPromptUpdate, + BaseProcessingInfo, + MultiModalPromptUpdates, PlaceholderFeaturesInfo, - PromptReplacement, PromptUpdate) + PromptReplacement, PromptUpdate, + ResolvedPromptUpdate) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors @@ -410,7 +412,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_tokens: list[str] = hf_processor.img_tokens # type: ignore @@ -431,24 +433,38 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): return [_IMAGE_TOKEN_ID] * num_image_tokens - num_images = mm_items.get_count("image", strict=False) - return [ PromptReplacement( modality="image", - target=image_token, + target=image_tokens.__getitem__, replacement=get_replacement_phi3v, - ) for 
image_token in image_tokens[:num_images] + ) ] + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + new_item_idx, + ) + + if cached_update.modality == "image": + hf_processor = self.info.get_hf_processor() + image_tokens: list[str] = hf_processor.img_tokens # type: ignore + new_update = new_update.with_target(image_tokens[new_item_idx]) + + return new_update + def _apply_prompt_updates( self, token_ids: list[int], - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: # align to hf behavior when there are images - if len(mm_item_counts): + if len(mm_prompt_updates): tokenizer = self.info.get_tokenizer() # to decode token_ids to the original text, we need to # 1. remove the first bos token @@ -484,7 +500,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): token_ids, text, placeholders = super()._apply_prompt_updates( token_ids=token_ids, mm_prompt_updates=mm_prompt_updates, - mm_item_counts=mm_item_counts, ) # Keep the behavior in line with HF processor diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py index e13b8276bf17a..492d4bfb7d3e6 100644 --- a/vllm/model_executor/models/phi4_multimodal.py +++ b/vllm/model_executor/models/phi4_multimodal.py @@ -30,7 +30,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (AudioProcessorItems, 
ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems, MultiModalDataParser) @@ -1029,11 +1029,11 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: tokenizer = self.info.get_tokenizer() - image_token_id = tokenizer.vocab[tokenizer.image_token] - audio_token_id = tokenizer.vocab[tokenizer.audio_token] + image_token_id: int = tokenizer.vocab[tokenizer.image_token] + audio_token_id: int = tokenizer.vocab[tokenizer.audio_token] hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) audio_processor = self.info.get_feature_extractor( @@ -1053,9 +1053,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): processor=hf_processor, ) - image_tokens = [image_token_id] * num_image_tokens - - return image_tokens + return [image_token_id] * num_image_tokens def get_audio_replacement_phi4mm(item_idx: int): audios = mm_items.get_items("audio", AudioProcessorItems) @@ -1066,9 +1064,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): audio_embed_size = self.info._compute_audio_embed_size( audio_frames) - audio_tokens = [audio_token_id] * audio_embed_size - - return audio_tokens + return [audio_token_id] * audio_embed_size return [ PromptReplacement( diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py index 493a4192d35ad..fcdfcb7bc1603 100644 --- a/vllm/model_executor/models/phi4flash.py +++ b/vllm/model_executor/models/phi4flash.py @@ -650,8 +650,12 @@ class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): num_mamba_layers = self.config.num_hidden_layers \ // 2 // self.config.mb_per_layer + 1 self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - 
*self._get_mamba_cache_shape()) + self.vllm_config, + num_mamba_layers, + *self._get_mamba_cache_shape(), + self.lm_head.weight.dtype, + self.lm_head.weight.dtype, + ) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) attn_metadata = get_forward_context().attn_metadata diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 73e8446e6dea7..211cbd9c819cc 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -21,13 +21,13 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptUpdate) + PromptUpdate, ResolvedPromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -262,9 +262,9 @@ class Phi4MMImageEncoder(nn.Module): img_features.shape[1])) assert base_feat_height == base_feat_height_target \ and base_feat_width == base_feat_height_target, \ - f'base_feat_height: {base_feat_height},"\ - f" base_feat_width: {base_feat_width}, "\ - f"expect {base_feat_height_target} features for hd transform' + (f"base_feat_height: {base_feat_height}, " + f"base_feat_width: {base_feat_width}, " + f"expect {base_feat_height_target} features for hd transform") # bs x max_num_crops x (24x24) x C img_features = img_features.view(bs, -1, @@ -802,7 +802,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): self, mm_items: 
MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: image_tokens: list[str] = self.info.image_tokens # type: ignore audio_tokens: list[str] = self.info.audio_tokens # type: ignore @@ -824,9 +824,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): processor=hf_processor, ) - image_tokens = [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens - - return image_tokens + return [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens def get_audio_replacement_phi4mm(item_idx: int): audios = mm_items.get_items("audio", AudioProcessorItems) @@ -837,28 +835,39 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): audio_embed_size = self.info._compute_audio_embed_size( audio_frames) - audio_tokens = [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size + return [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size - return audio_tokens - - num_images = mm_items.get_count("image", strict=False) - num_audios = mm_items.get_count("audio", strict=False) - - image_repl = [ + return [ PromptReplacement( modality="image", - target=image_token, + target=image_tokens.__getitem__, replacement=get_image_replacement_phi4mm, - ) for image_token in image_tokens[:num_images] - ] - audio_repl = [ + ), PromptReplacement( modality="audio", - target=audio_token, + target=audio_tokens.__getitem__, replacement=get_audio_replacement_phi4mm, - ) for audio_token in audio_tokens[:num_audios] + ), ] - return image_repl + audio_repl + + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + new_item_idx, + ) + + if cached_update.modality == "image": + image_tokens: list[str] = self.info.image_tokens # type: ignore + new_update = new_update.with_target(image_tokens[new_item_idx]) + elif 
cached_update.modality == "audio": + audio_tokens: list[str] = self.info.audio_tokens # type: ignore + new_update = new_update.with_target(audio_tokens[new_item_idx]) + + return new_update @MULTIMODAL_REGISTRY.register_processor( diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 41eaf372785eb..461b9c85d1c22 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -5,7 +5,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass, fields from functools import cached_property -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn @@ -33,13 +33,14 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, NestedTensors) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, MultiModalHashes, + BaseProcessingInfo, + MultiModalProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs @@ -47,6 +48,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import (MistralTokenizer, cached_tokenizer_from_config) +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, 
SupportsMultiModal, SupportsPP from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, @@ -67,15 +69,20 @@ except ImportError: PATCH_MERGE = "patch_merge" -class PixtralImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - - images: Union[torch.Tensor, list[torch.Tensor]] +class PixtralImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, num_channels, image_width, image_height)` - + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image + The result of stacking `ImageEncoding.tokens` from each prompt. """ + type: Literal["pixel_values"] = "pixel_values" + + images: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"})] class PixtralProcessorAdapter: @@ -273,7 +280,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -307,24 +314,16 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: - ( - prompt_ids, - mm_kwargs, - mm_hashes, - _, - ) = super()._cached_apply_hf_processor( + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: + prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, ) # NOTE: The tokens are already inserted by the chat template - return prompt_ids, 
mm_kwargs, mm_hashes, True + return prompt_ids, mm_info, True @MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor, @@ -388,10 +387,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, if images is None: return None - if not isinstance(images, (torch.Tensor, list)): - raise ValueError("Incorrect type of images. " - f"Got type: {type(images)}") - return PixtralImagePixelInputs( type="pixel_values", images=flatten_bn(images), diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 8b1df66f02805..e5034b536266a 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -767,8 +767,12 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, self.vllm_config.parallel_config, LayerBlockType.mamba) self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) + self.vllm_config, + num_mamba_layers, + *self._get_mamba_cache_shape(), + self.lm_head.weight.dtype, + self.lm_head.weight.dtype, + ) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index 20f423cc7603d..f46d6375e1f61 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -18,7 +18,7 @@ """Inference-only IBM/NASA Prithvi Geospatial model.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Optional, Union +from typing import Any, Optional, Union import torch import torch.nn as nn @@ -27,21 +27,61 @@ from transformers import BatchFeature from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import ( - 
IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput, - default_pooling_type) from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalFieldElem, MultiModalInputs, - MultiModalKwargs, MultiModalKwargsItem, - MultiModalSharedField, PlaceholderRange) -from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.inputs import (ImageItem, ModalityData, + MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputs, MultiModalKwargsItems, + PlaceholderRange) +from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems, + MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from .interfaces import (IsAttentionFree, MultiModalEmbeddings, + SupportsMultiModalWithRawInput) +from .interfaces_base import default_pooling_type + + +def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]): + # This model receives in input a multi-dimensional tensor representing + # a single image patch and therefore it is not to be split + # into multiple elements, but rather to be considered a single one. + # Hence, the decision of using a MultiModalSharedField. + # The expected shape is (num_channels, width, height). + + # This model however allows the user to also submit multiple image + # patches as a batch, adding a further dimension to the above shape. + # At this stage we only support submitting one patch per request and + # batching is achieved via vLLM batching. + # TODO (christian-pinto): enable support for multi patch requests + # in tandem with vLLM batching. 
+ return dict( + pixel_values=MultiModalFieldConfig.shared(batch_size=1, + modality="image"), + location_coords=MultiModalFieldConfig.shared(batch_size=1, + modality="image"), + ) + + +class PrithviGeoSpatialMAEMultiModalDataParser(MultiModalDataParser): + + def _parse_image_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + ) -> Optional[ModalityDataItems[Any, Any]]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="image", + required_fields={"pixel_values", "location_coords"}, + fields_factory=_prithvi_field_config, + ) + + return super()._parse_image_data(data) + class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): @@ -63,32 +103,32 @@ class PrithviGeoSpatialMAEInputBuilder( # This model input is fixed and is in the form of a torch Tensor. # The size of pixel_values might change in the cases where we resize # the input but never exceeds the dimensions below. - return { + image_data = { "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16), "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), } + return {"image": image_data} + class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): + def _get_data_parser(self) -> MultiModalDataParser: + return PrithviGeoSpatialMAEMultiModalDataParser() + def _get_mm_fields_config( self, hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - pixel_values=MultiModalFieldConfig.shared(batch_size=1, - modality="image"), - location_coords=MultiModalFieldConfig.shared(batch_size=1, - modality="image"), - ) + return _prithvi_field_config(hf_inputs) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: return [] @@ -98,46 +138,32 @@ class 
PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, ) -> MultiModalInputs: - mm_kwargs = {} + if "image" in mm_data: + image_data = mm_data["image"] + else: + image_data = mm_data + mm_data = {"image": mm_data} - for k, v in mm_data.items(): - if isinstance(v, dict) and k == "image": - mm_kwargs.update(v) - else: - mm_kwargs[k] = v + mm_items = self._to_mm_items(mm_data) + mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs, + tokenization_kwargs or {}) mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} - # This model receives in input a multi-dimensional tensor representing - # a single image patch and therefore it is not to be split - # into multiple elements, but rather to be considered a single one. - # Hence, the decision of using a MultiModalSharedField. - # The expected shape is (num_channels, width, height). + mm_processed_data = BatchFeature(image_data) - # This model however allows the user to also submit multiple image - # patches as a batch, adding a further dimension to the above shape. - # At this stage we only support submitting one patch per request and - # batching is achieved via vLLM batching. - # TODO (christian-pinto): enable support for multi patch requests - # in tandem with vLLM batching. 
- multimodal_kwargs_items = [ - MultiModalKwargsItem.from_elems([ - MultiModalFieldElem( - modality="image", - key=key, - data=data, - field=MultiModalSharedField(1), - ) for key, data in mm_kwargs.items() - ]) - ] + mm_kwargs = MultiModalKwargsItems.from_hf_inputs( + mm_processed_data, + self._get_mm_fields_config(mm_processed_data, + hf_processor_mm_kwargs), + ) return MultiModalInputs( type="multimodal", prompt=prompt, prompt_token_ids=[1], - mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items), - mm_hashes=None, + mm_kwargs=mm_kwargs, + mm_hashes=mm_hashes, mm_placeholders=mm_placeholders, ) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 7304fbf120ccd..27c1e68c6704b 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -32,6 +32,7 @@ from torch import nn from transformers import Qwen2Config from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -51,7 +52,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import is_interleaved -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -159,7 +160,9 @@ class Qwen2Attention(nn.Module): rope_scaling=rope_scaling, dual_chunk_attention_config=dual_chunk_attention_config, ) - self.attn = Attention( + attn_cls = (EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY else Attention) + self.attn = attn_cls( self.num_heads, self.head_dim, 
self.scaling, @@ -330,7 +333,7 @@ class Qwen2Model(nn.Module): else: self.norm = PPMissingLayer() - self.aux_hidden_state_layers: tuple[int] = tuple() + self.aux_hidden_state_layers = tuple[int, ...]() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -439,7 +442,7 @@ class Qwen2Model(nn.Module): return loaded_params -class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -485,6 +488,13 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index e95295c31885a..5c64c81547e65 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -25,7 +25,7 @@ from collections.abc import Iterable, Mapping, Sequence from copy import copy from functools import partial -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.nn as nn @@ -47,18 +47,19 @@ from vllm.model_executor.models.qwen2_5_vl import ( Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs) from vllm.model_executor.models.qwen2_audio import ( - Qwen2AudioInputs, Qwen2AudioProcessingInfo, + Qwen2AudioFeatureInputs, Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths) from 
vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (ImageItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems, ModalityDataItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, + MultiModalPromptUpdates, PlaceholderFeaturesInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder @@ -78,37 +79,57 @@ except (ImportError, ModuleNotFoundError): logger = init_logger(__name__) -def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]): - audio_feature_lengths = hf_inputs.get("audio_feature_lengths", - torch.empty((0, ))) +def create_qwen2_5_omni_thinker_field_factory( + spatial_merge_size: int +) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, + MultiModalFieldConfig]]: - image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) - image_grid_sizes = image_grid_thw.prod(-1) + def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, + torch.Tensor]): + audio_feature_lengths = hf_inputs.get("audio_feature_lengths", + torch.empty((0, ))) - video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) - video_grid_sizes = video_grid_thw.prod(-1) + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_pixel_grid_sizes = image_grid_thw.prod(-1) + image_embed_grid_sizes = (image_pixel_grid_sizes // + spatial_merge_size // spatial_merge_size) - return dict( - input_audio_features=MultiModalFieldConfig.flat_from_sizes( - "audio", audio_feature_lengths, dim=1), - feature_attention_mask=MultiModalFieldConfig.batched("audio"), - 
audio_feature_lengths=MultiModalFieldConfig.batched("audio"), - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_grid_thw=MultiModalFieldConfig.batched("image"), - pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_grid_thw=MultiModalFieldConfig.batched("video"), - second_per_grid_ts=MultiModalFieldConfig.batched("video"), - ) + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size // + spatial_merge_size) + + num_videos = len(video_grid_sizes) + + return dict( + input_audio_features=MultiModalFieldConfig.flat_from_sizes( + "audio", audio_feature_lengths, dim=1), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), + audio_feature_lengths=MultiModalFieldConfig.batched("audio"), + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_pixel_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_embed_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_embed_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + second_per_grid_ts=MultiModalFieldConfig.batched("video"), + use_audio_in_video=MultiModalFieldConfig.shared( + "video", num_videos), + ) + + return _qwen2_5_omni_thinker_field_config class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser): + def __init__(self, spatial_merge_size: int, *args, **kwargs): + self._spatial_merge_size = spatial_merge_size + super().__init__(self._spatial_merge_size, *args, **kwargs) + def 
_parse_audio_data( self, data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], @@ -120,7 +141,8 @@ class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser): required_fields={ "input_audio_features", "audio_feature_lengths" }, - fields_factory=_qwen2_5_omni_thinker_field_config, + fields_factory=create_qwen2_5_omni_thinker_field_factory( + self._spatial_merge_size), ) return super()._parse_audio_data(data) @@ -210,6 +232,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor( def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() return Qwen2_5OmniThinkerMultiModalDataParser( + spatial_merge_size=self.info.get_hf_config( + ).vision_config.spatial_merge_size, target_sr=feature_extractor.sampling_rate) def _call_hf_processor( @@ -246,6 +270,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor( if ('audio_feature_lengths' not in hf_inputs and feature_attention_mask is not None): hf_inputs['audio_feature_lengths'] = feature_attention_mask.sum(-1) + + video_second_per_grid = hf_inputs.get("video_second_per_grid", None) + if video_second_per_grid is not None: + hf_inputs["second_per_grid_ts"] = video_second_per_grid + + use_audio_in_video = mm_kwargs.get("use_audio_in_video", False) + hf_inputs["use_audio_in_video"] = torch.tensor(use_audio_in_video) + return hf_inputs def _get_mm_fields_config( @@ -253,38 +285,32 @@ class Qwen2_5OmniThinkerMultiModalProcessor( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return _qwen2_5_omni_thinker_field_config(hf_inputs) + return create_qwen2_5_omni_thinker_field_factory( + self.info.get_hf_config().vision_config.spatial_merge_size)( + hf_inputs) def _maybe_apply_prompt_updates( self, mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], prompt_ids: list[int], - mm_kwargs: MultiModalKwargs, + mm_kwargs: MultiModalKwargsItems, + mm_prompt_updates: MultiModalPromptUpdates, 
is_update_applied: bool, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: """ Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. """ - unbound_prompt_updates = self._get_prompt_updates( - mm_items, - hf_processor_mm_kwargs, - mm_kwargs, - ) - mm_prompt_updates = self._bind_and_group_updates( - unbound_prompt_updates) - mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) - use_audio_in_video = hf_processor_mm_kwargs.get( - "use_audio_in_video", False) + use_audio_in_video = (all( + item["use_audio_in_video"].data + for item in mm_kwargs["video"]) if "video" in mm_kwargs else False) if is_update_applied: mm_placeholders = self._find_mm_placeholders( - mm_prompt_updates, prompt_ids, - mm_item_counts, + mm_prompt_updates, ) self._validate_mm_placeholders( mm_placeholders, @@ -301,7 +327,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor( ) = self._apply_prompt_updates( prompt_ids, mm_prompt_updates, - mm_item_counts, ) self._validate_mm_placeholders( mm_placeholders, @@ -311,16 +336,13 @@ class Qwen2_5OmniThinkerMultiModalProcessor( tokenizer = self.info.get_tokenizer() prompt = decode_tokens(tokenizer, prompt_ids) - if use_audio_in_video: - mm_kwargs["use_audio_in_video"] = True - return prompt_ids, prompt, mm_placeholders def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() @@ -335,8 +357,9 @@ class Qwen2_5OmniThinkerMultiModalProcessor( image_token_id = vocab[image_token] video_token_id = vocab[video_token] - audio_feature_lengths = out_mm_kwargs.get("audio_feature_lengths") - feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") + out_mm_data = out_mm_kwargs.get_data() + audio_feature_lengths = 
out_mm_data.get("audio_feature_lengths") + feature_attention_mask = out_mm_data.get("feature_attention_mask") if audio_feature_lengths is None and feature_attention_mask is None: audio_output_lengths = [] elif audio_feature_lengths is not None: @@ -366,7 +389,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( return [audio_token_id] * num_features def get_replacement_qwen2_vision(item_idx: int, modality: str): - grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + grid_thw = out_mm_data[f"{modality}_grid_thw"][item_idx] assert isinstance(grid_thw, torch.Tensor) merge_length = image_processor.merge_size**2 @@ -382,7 +405,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( audio_num_features = audio_output_lengths[audio_in_video_item_idx + item_idx] - video_grid_thw = out_mm_kwargs["video_grid_thw"][item_idx] + video_grid_thw = out_mm_data["video_grid_thw"][item_idx] audio_in_video_item_idx += 1 @@ -511,7 +534,7 @@ class Qwen2_5OmniConditionalGenerationMixin: return torch.concat(mm_input, dim=dim) def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[Qwen2AudioInputs]: + self, **kwargs: object) -> Optional[Qwen2AudioFeatureInputs]: input_audio_features = kwargs.pop('input_audio_features', None) audio_feature_lengths = kwargs.pop('audio_feature_lengths', None) feature_attention_mask = kwargs.pop('feature_attention_mask', None) @@ -525,9 +548,10 @@ class Qwen2_5OmniConditionalGenerationMixin: if not isinstance(input_audio_features, (torch.Tensor, list)): raise ValueError("Incorrect type of audio input features. 
" f"Got type: {type(input_audio_features)}") - return Qwen2AudioInputs(input_features=input_audio_features, - audio_feature_lengths=audio_feature_lengths, - feature_attention_mask=feature_attention_mask) + return Qwen2AudioFeatureInputs( + input_features=input_audio_features, + audio_feature_lengths=audio_feature_lengths, + feature_attention_mask=feature_attention_mask) def _parse_and_validate_image_input( self, @@ -607,7 +631,7 @@ class Qwen2_5OmniConditionalGenerationMixin: def _process_audio_input( self, - audio_input: Qwen2AudioInputs, + audio_input: Qwen2AudioFeatureInputs, audio_hashes: list[str] = None, cached_audio_features: torch.Tensor = None, ) -> torch.Tensor: diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 5bcbcc4f0e37b..b528083b7c9cc 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -45,10 +45,14 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm +# yapf: disable from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, + MergedReplicatedLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) +# yapf: enable from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( @@ -57,6 +61,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig +from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from 
vllm.transformers_utils.config import uses_mrope @@ -130,7 +135,7 @@ class Qwen2_5_VLVideoPixelInputs(TypedDict): second_per_grid_ts: torch.Tensor """ - The video time interval (in seconds) for each grid along the temporal + The video time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. Returned when `videos` is not `None`. """ @@ -170,19 +175,25 @@ class Qwen2_5_VisionMLP(nn.Module): bias: bool = False, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( + cls_gate_up_proj = (MergedReplicatedLinear if use_data_parallel else + MergedColumnParallelLinear) + self.gate_up_proj = cls_gate_up_proj( input_size=in_features, output_sizes=[hidden_features] * 2, # [gate_proj, up_proj] bias=bias, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(hidden_features, - in_features, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + + cls_down_proj = (ReplicatedLinear + if use_data_parallel else RowParallelLinear) + self.down_proj = cls_down_proj(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") self.act_fn = act_fn def forward(self, x: torch.Tensor): @@ -220,28 +231,42 @@ class Qwen2_5_VisionAttention(nn.Module): projection_size: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() # Per attention head and per partition values. 
- self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_size = (1 if use_data_parallel else + parallel_state.get_tensor_model_parallel_world_size()) self.tp_rank = parallel_state.get_tensor_model_parallel_rank() self.hidden_size_per_attention_head = dist_utils.divide( projection_size, num_heads) self.num_attention_heads_per_partition = dist_utils.divide( num_heads, self.tp_size) - self.qkv = QKVParallelLinear( - hidden_size=embed_dim, - head_size=self.hidden_size_per_attention_head, - total_num_heads=num_heads, - total_num_kv_heads=num_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv") - self.proj = RowParallelLinear(input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj") + if use_data_parallel: + self.qkv = ReplicatedLinear(embed_dim, + self.hidden_size_per_attention_head * + 3 * num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv") + + else: + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv") + + cls_proj = (ReplicatedLinear + if use_data_parallel else RowParallelLinear) + self.proj = cls_proj(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj") # Detect attention implementation. 
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) @@ -302,8 +327,6 @@ class Qwen2_5_VisionAttention(nn.Module): k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) if self.is_flash_attn_backend: - # from vllm_flash_attn.flash_attn_interface import ( - # flash_attn_varlen_func) if self.attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: @@ -370,23 +393,27 @@ class Qwen2_5_VisionBlock(nn.Module): norm_layer: Optional[Callable[[int], nn.Module]] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.norm1 = norm_layer(dim) self.norm2 = norm_layer(dim) - self.attn = Qwen2_5_VisionAttention(embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Qwen2_5_VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel) self.mlp = Qwen2_5_VisionMLP(dim, mlp_hidden_dim, act_fn=act_fn, bias=True, quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel) def forward( self, @@ -445,24 +472,30 @@ class Qwen2_5_VisionPatchMerger(nn.Module): spatial_merge_size: int = 2, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.hidden_size = context_dim * (spatial_merge_size**2) if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.ln_q = norm_layer(context_dim) + + cls_fc1 = (ReplicatedLinear + if use_data_parallel else ColumnParallelLinear) + cls_fc2 = (ReplicatedLinear + if use_data_parallel else RowParallelLinear) self.mlp = nn.ModuleList([ - ColumnParallelLinear(self.hidden_size, - self.hidden_size, - bias=True, - 
quant_config=quant_config, - prefix=f"{prefix}.mlp.0"), + cls_fc1(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0"), nn.GELU(), - RowParallelLinear(self.hidden_size, - d_model, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.2"), + cls_fc2(self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2"), ]) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -514,6 +547,7 @@ class Qwen2_5_VisionTransformer(nn.Module): norm_eps: float = 1e-6, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -523,6 +557,8 @@ class Qwen2_5_VisionTransformer(nn.Module): depth = vision_config.depth self.hidden_size = vision_config.hidden_size self.num_heads = vision_config.num_heads + self.use_data_parallel = use_data_parallel + self.out_hidden_size = vision_config.out_hidden_size # args for get_window_index_thw self.window_size = vision_config.window_size @@ -550,7 +586,8 @@ class Qwen2_5_VisionTransformer(nn.Module): vision_config.hidden_act), norm_layer=norm_layer, quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel) for layer_idx in range(depth) ]) self.merger = Qwen2_5_VisionPatchMerger( @@ -560,6 +597,7 @@ class Qwen2_5_VisionTransformer(nn.Module): spatial_merge_size=self.spatial_merge_size, quant_config=quant_config, prefix=f"{prefix}.merger", + use_data_parallel=use_data_parallel, ) self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) @@ -767,7 +805,6 @@ class Qwen2_5_VisionTransformer(nn.Module): if weight_name not in name: continue name = name.replace(weight_name, param_name) - param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -815,6 +852,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, 
SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsQuant): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -826,6 +868,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, "model.": "language_model.model.", }) + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -840,6 +884,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.config = config self.multimodal_config = multimodal_config @@ -851,6 +896,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, quant_config=self._maybe_ignore_quant_config( self.quant_config), prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, ) else: self.visual = None @@ -973,7 +1019,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"] - image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) + + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, grid_thw_list) + else: + image_embeds = self.visual(pixel_values, + grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. 
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync @@ -995,8 +1047,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, video_embeds = video_input["video_embeds"].type(self.visual.dtype) else: pixel_values_videos = video_input["pixel_values_videos"] - video_embeds = self.visual(pixel_values_videos, - grid_thw=grid_thw_list) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values_videos, grid_thw_list) + else: + video_embeds = self.visual(pixel_values_videos, + grid_thw=grid_thw_list) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 3ef55cd704cf0..86b4a9a018c76 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -23,7 +23,7 @@ # limitations under the License. """Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Optional, TypedDict, Union +from typing import Any, Literal, Optional, TypedDict, Union import torch import torch.nn as nn @@ -36,9 +36,11 @@ from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) -from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, +from vllm.multimodal.inputs import (AudioItem, ModalityData, + MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems, + ModalityDataItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, 
BaseProcessingInfo, PromptReplacement, @@ -52,7 +54,8 @@ from .utils import (AutoWeightsLoader, init_vllm_registered_model, # # === Audio Inputs === # -class Qwen2AudioInputs(TypedDict): +class Qwen2AudioFeatureInputs(TypedDict): + type: Literal["audio_features"] input_features: torch.Tensor """Shape: `(num_audios, num_mel_bins, 3000)`""" @@ -60,6 +63,16 @@ class Qwen2AudioInputs(TypedDict): """Shape: `(num_audios, 3000)`""" +class Qwen2AudioEmbeddingInputs(TypedDict): + type: Literal["audio_embeds"] + audio_embeds: list[torch.Tensor] + """Shape: `(num_audio_features, hidden_size)` + `hidden_size` must match the hidden size of language model backbone. + """ + + +Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs] + # === Audio Encoder === # @@ -128,12 +141,38 @@ class Qwen2AudioDummyInputsBuilder( } +def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]): + return dict( + audio_embeds=MultiModalFieldConfig.batched("audio"), + input_features=MultiModalFieldConfig.batched("audio"), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), + ) + + +class Qwen2AudioMultiModalDataParser(MultiModalDataParser): + + def _parse_audio_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]], + ) -> Optional[ModalityDataItems[Any, Any]]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="audio", + required_fields={"audio_embeds"}, + fields_factory=_qwen2audio_field_config, + ) + + return super()._parse_audio_data(data) + + class Qwen2AudioMultiModalProcessor( BaseMultiModalProcessor[Qwen2AudioProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() - return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) + return Qwen2AudioMultiModalDataParser( + target_sr=feature_extractor.sampling_rate) def _call_hf_processor( self, @@ -173,17 +212,15 @@ class Qwen2AudioMultiModalProcessor( hf_inputs: BatchFeature, 
hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - input_features=MultiModalFieldConfig.batched("audio"), - feature_attention_mask=MultiModalFieldConfig.batched("audio"), - ) + return _qwen2audio_field_config(hf_inputs) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() @@ -199,7 +236,8 @@ class Qwen2AudioMultiModalProcessor( audio_bos_id = vocab[audio_bos_token] audio_eos_id = vocab[audio_eos_token] - feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") + out_mm_data = out_mm_kwargs.get_data() + feature_attention_mask = out_mm_data.get("feature_attention_mask") if feature_attention_mask is None: audio_output_lengths = [] else: @@ -210,7 +248,15 @@ class Qwen2AudioMultiModalProcessor( audio_output_lengths = audio_output_lens.tolist() def get_replacement_qwen2_audio(item_idx: int): - num_features = audio_output_lengths[item_idx] + + if audio_output_lengths: + num_features = audio_output_lengths[item_idx] + else: + audio_embeds = out_mm_data["audio_embeds"][item_idx] + assert len(audio_embeds.shape + ) == 2, "audio_embeds must be a 2D tensor" + num_features = audio_embeds.shape[0] + if num_features == 0: audios = mm_items.get_items("audio", AudioProcessorItems) audio_len = audios.get_audio_length(item_idx) @@ -285,21 +331,39 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, def _parse_and_validate_audio_input( self, **kwargs: object) -> Optional[Qwen2AudioInputs]: input_features = kwargs.pop('input_features', None) + audio_embeds = kwargs.pop('audio_embeds', None) feature_attention_mask = kwargs.pop('feature_attention_mask', None) - if input_features is None: - return None - 
input_features = self._validate_and_reshape_mm_tensor( - input_features, 'input_features') - feature_attention_mask = self._validate_and_reshape_mm_tensor( - feature_attention_mask, 'feature_attention_mask') - if not isinstance(input_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio input features. " - f"Got type: {type(input_features)}") - return Qwen2AudioInputs(input_features=input_features, - feature_attention_mask=feature_attention_mask) - def _process_audio_input(self, - audio_input: Qwen2AudioInputs) -> torch.Tensor: + if input_features is None and audio_embeds is None: + return None + + if audio_embeds is not None: + if not isinstance(audio_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio embeds. " + f"Got type: {type(audio_embeds)}") + audio_embeds = self._validate_and_reshape_mm_tensor( + audio_embeds, "audio_embeds") + return Qwen2AudioEmbeddingInputs(type="audio_embeds", + audio_embeds=audio_embeds) + + if input_features is not None: + input_features = self._validate_and_reshape_mm_tensor( + input_features, 'input_features') + feature_attention_mask = self._validate_and_reshape_mm_tensor( + feature_attention_mask, 'feature_attention_mask') + return Qwen2AudioFeatureInputs( + type="audio_features", + input_features=input_features, + feature_attention_mask=feature_attention_mask) + + raise AssertionError("This line should be unreachable.") + + def _process_audio_input( + self, audio_input: Qwen2AudioInputs + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + if audio_input["type"] == "audio_embeds": + audio_embeds = audio_input["audio_embeds"] + return tuple(audio_embeds) input_features = audio_input["input_features"] feature_attention_mask = audio_input["feature_attention_mask"] diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index e0a30e04c602a..421b43563bade 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py 
@@ -18,7 +18,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type +from .interfaces import SupportsLoRA, SupportsPP +from .interfaces_base import default_pooling_type from .qwen2 import Qwen2Model from .utils import AutoWeightsLoader, maybe_prefix diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index f2d438b3850b8..ae7a8d8d7a5b9 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -58,7 +58,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (ImageItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, VideoItem) + MultiModalKwargsItems, VideoItem) from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, ModalityDataItems, MultiModalDataItems, MultiModalDataParser) @@ -329,8 +329,6 @@ class Qwen2VisionAttention(nn.Module): k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) if self.is_flash_attn_backend: - # from vllm_flash_attn.flash_attn_interface import ( - # flash_attn_varlen_func) if self.attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: @@ -701,29 +699,46 @@ class Qwen2VisionTransformer(nn.Module): return loaded_params -def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): - image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) - image_grid_sizes = image_grid_thw.prod(-1) +def _create_qwen2vl_field_factory( + spatial_merge_size: int +) -> Callable[ + [Mapping[str, torch.Tensor]], + Mapping[str, MultiModalFieldConfig], +]: - video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) - video_grid_sizes = video_grid_thw.prod(-1) + def 
_qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_pixel_grid_sizes = image_grid_thw.prod(-1) + image_embed_grid_sizes = (image_pixel_grid_sizes // + spatial_merge_size // spatial_merge_size) - return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_grid_thw=MultiModalFieldConfig.batched("image"), - pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_grid_thw=MultiModalFieldConfig.batched("video"), - ) + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size // + spatial_merge_size) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_pixel_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_embed_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_embed_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + + return _qwen2vl_field_config class Qwen2VLMultiModalDataParser(MultiModalDataParser): + def __init__(self, spatial_merge_size: int, *args, **kwargs): + self._spatial_merge_size = spatial_merge_size + super().__init__(*args, **kwargs) + def _parse_image_data( self, data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], @@ -733,7 +748,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser): data, modality="image", required_fields={"image_embeds", "image_grid_thw"}, - fields_factory=_qwen2vl_field_config, + 
fields_factory=_create_qwen2vl_field_factory( + self._spatial_merge_size), ) return super()._parse_image_data(data) @@ -747,7 +763,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser): data, modality="video", required_fields={"video_embeds", "video_grid_thw"}, - fields_factory=_qwen2vl_field_config, + fields_factory=_create_qwen2vl_field_factory( + self._spatial_merge_size), ) return super()._parse_video_data(data) @@ -969,13 +986,14 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] ): def _get_data_parser(self) -> MultiModalDataParser: - return Qwen2VLMultiModalDataParser() + return Qwen2VLMultiModalDataParser( + self.info.get_hf_config().vision_config.spatial_merge_size) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor( @@ -991,7 +1009,8 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] merge_length = image_processor.merge_size**2 def get_replacement_qwen2vl(item_idx: int, modality: str): - grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) num_tokens = int(grid_thw.prod()) // merge_length @@ -1011,7 +1030,9 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return _qwen2vl_field_config(hf_inputs) + return _create_qwen2vl_field_factory( + self.info.get_hf_config().vision_config.spatial_merge_size)( + hf_inputs) @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, @@ -1225,7 +1246,6 @@ class 
Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] - return None # The result multimodal_embeddings is tuple of tensors, with each # tensor correspoending to a multimodal data item (image or video). diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index 2060206633702..dddb47048a1fc 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -304,10 +304,10 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None: + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.model.aux_hidden_state_layers = layers - def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]: + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: num_layers = len(self.model.layers) return (2, num_layers // 2, num_layers - 3) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 61b16b6a1d2d8..8498f61b35fdd 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -45,6 +45,9 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -121,11 +124,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module): # Load balancing settings. 
vllm_config = get_current_vllm_config() - parallel_config = vllm_config.parallel_config + eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb self.n_logical_experts = self.n_routed_experts - self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_redundant_experts = eplb_config.num_redundant_experts self.n_physical_experts = (self.n_logical_experts + self.n_redundant_experts) self.n_local_physical_experts = self.n_physical_experts // self.ep_size @@ -139,18 +142,27 @@ class Qwen3MoeSparseMoeBlock(nn.Module): top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, - reduce_results=False, + reduce_results=True, renormalize=config.norm_topk_prob, quant_config=quant_config, prefix=f"{prefix}.experts", enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts) - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=f"{prefix}.gate") + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid gate quantization. + # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4 + if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + return None + return quant_config def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
@@ -163,10 +175,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module): final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) - if self.tp_size > 1: - final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 - final_hidden_states) - return final_hidden_states.view(orig_shape) @@ -367,7 +375,8 @@ class Qwen3MoeModel(nn.Module): quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config enable_eplb = parallel_config.enable_eplb - self.num_redundant_experts = parallel_config.num_redundant_experts + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -685,4 +694,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, return loader.load_weights(weights) def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return self.model.get_expert_mapping() \ No newline at end of file + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 4c3fd6b5156d0..2950ca664a98f 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -33,7 +33,7 @@ from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -627,7 +627,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + 
out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: tokenizer = self.info.get_tokenizer() special_tokens: dict[str, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b817615b43564..80eac78cdfadb 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -25,11 +25,14 @@ from vllm.logger import init_logger from vllm.transformers_utils.dynamic_module import ( try_get_class_from_dynamic_module) -from .interfaces import (get_default_pooling_type, has_inner_state, has_noops, - is_attention_free, is_hybrid, supports_cross_encoding, - supports_multimodal, supports_multimodal_raw_input, - supports_pp, supports_transcription, supports_v0_only) -from .interfaces_base import is_pooling_model, is_text_generation_model +from .interfaces import (has_inner_state, has_noops, is_attention_free, + is_hybrid, supports_cross_encoding, + supports_multimodal, + supports_multimodal_encoder_tp_data, + supports_multimodal_raw_input, supports_pp, + supports_transcription, supports_v0_only) +from .interfaces_base import (get_default_pooling_type, is_pooling_model, + is_text_generation_model) logger = init_logger(__name__) @@ -93,6 +96,7 @@ _TEXT_GENERATION_MODELS = { "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"), "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # noqa: E501 # For decapoda-research/llama-* @@ -129,6 +133,7 @@ _TEXT_GENERATION_MODELS = { "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"), "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), + "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"), "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", 
"StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), @@ -141,6 +146,7 @@ _TEXT_GENERATION_MODELS = { # [Encoder-decoder] "BartModel": ("bart", "BartForConditionalGeneration"), "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), + "MBartForConditionalGeneration": ("bart", "MBartForConditionalGeneration"), } _EMBEDDING_MODELS = { @@ -203,6 +209,7 @@ _MULTIMODAL_MODELS = { "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"), # noqa: E501 "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), + "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"), # noqa: E501 "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501 @@ -216,6 +223,7 @@ _MULTIMODAL_MODELS = { "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501 "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"), + "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"), "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"), "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), @@ -230,6 +238,7 @@ _MULTIMODAL_MODELS = { "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), "Ovis": ("ovis", "Ovis"), + "Ovis2_5": ("ovis2_5", "Ovis2_5"), "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": 
("phi3v", "Phi3VForCausalLM"), "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), @@ -247,6 +256,7 @@ _MULTIMODAL_MODELS = { "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501 "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 # [Encoder-decoder] + "DonutForConditionalGeneration": ("donut", "DonutForConditionalGeneration"), "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501 @@ -262,7 +272,9 @@ _SPECULATIVE_DECODING_MODELS = { "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), + "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), + "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "MedusaModel": ("medusa", "Medusa"), # Temporarily disabled. 
@@ -314,6 +326,7 @@ class _ModelInfo: supports_cross_encoding: bool supports_multimodal: bool supports_multimodal_raw_input: bool + supports_multimodal_encoder_tp_data: bool supports_pp: bool has_inner_state: bool is_attention_free: bool @@ -333,6 +346,8 @@ class _ModelInfo: supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), supports_multimodal_raw_input=supports_multimodal_raw_input(model), + supports_multimodal_encoder_tp_data= + supports_multimodal_encoder_tp_data(model), supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 32a4a2c9a2694..2bfa51162910b 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -9,7 +9,6 @@ from torch import nn from transformers import RobertaConfig from vllm.config import VllmConfig -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool, DispatchPooler, Pooler) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -23,7 +22,8 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, from vllm.sequence import IntermediateTensors from .bert_with_rope import BertWithRope, JinaRobertaModel -from .interfaces import SupportsCrossEncoding, default_pooling_type +from .interfaces import SupportsCrossEncoding +from .interfaces_base import default_pooling_type class RobertaEmbedding(nn.Module): @@ -100,7 +100,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) - self.padding_idx = vllm_config.model_config.hf_config.pad_token_id + self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id def forward( self, @@ -178,7 +178,7 @@ class 
RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - self.padding_idx = vllm_config.model_config.hf_config.pad_token_id + self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id self.num_labels = config.num_labels self.roberta = BertModel(vllm_config=vllm_config, @@ -233,58 +233,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): intermediate_tensors=intermediate_tensors) -# Adapted from transformers -def create_position_ids_from_input_ids(input_ids, - padding_idx, - past_key_values_length=0): - """ - Replace non-padding symbols with their position numbers. - Position numbers begin at padding_idx+1. Padding symbols - are ignored. This is modified from fairseq's `utils.make_positions`. - - Args: - x: torch.Tensor x: - - Returns: torch.Tensor - """ - # The series of casts and type-conversions here are carefully - # balanced to both work with ONNX export and XLA. 
- mask = input_ids.ne(padding_idx).int() - - incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) + - past_key_values_length) * mask - - return incremental_indices.long() + padding_idx - - def replace_roberta_positions(input_ids: torch.Tensor, position_ids: torch.Tensor, padding_idx: int) -> None: - - seq_lens: Optional[torch.Tensor] = None - attn_metadata = get_forward_context().attn_metadata - if attn_metadata is not None: # can be None during warmup - if isinstance(attn_metadata, dict): - attn_metadata = next(iter(attn_metadata.values())) - # TODO: remove "seq_lens_tensor" after V0 is removed - seq_lens = getattr(attn_metadata, "seq_lens_tensor", - getattr(attn_metadata, "seq_lens", None)) - - if seq_lens is not None: - assert isinstance(seq_lens, torch.Tensor) - - # Replace position ids because in RoBERTa models - # they have to start at padding_idx + 1 and ignore - # existing padding tokens - # References: - # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 - # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 - token_list = torch.split(input_ids[:torch.sum(seq_lens)], - seq_lens.tolist()) - - offset = 0 - for tokens in token_list: - length = tokens.shape[0] - position_ids[offset:offset+length] = \ - create_position_ids_from_input_ids(tokens, padding_idx) - offset = offset + length + # Replace position ids because in RoBERTa models + # they have to start at padding_idx + 1 and ignore + # existing padding tokens + # References: + # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 + # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 + # vllm does not use 
padding tokens, let's make things simpler + position_ids += padding_idx + 1 diff --git a/vllm/model_executor/models/rvl.py b/vllm/model_executor/models/rvl.py new file mode 100644 index 0000000000000..efdb010046634 --- /dev/null +++ b/vllm/model_executor/models/rvl.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Mapping + +import torch +import torch.nn as nn +from transformers.activations import GELUActivation + +from vllm.config import VllmConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalDataDict + +from .llava_next import (LlavaDummyInputsBuilder, LlavaNextMultiModalProcessor, + LlavaNextProcessingInfo) +from .llava_onevision import LlavaOnevisionForConditionalGeneration +from .utils import WeightsMapper + + +class RVLProcessingInfo(LlavaNextProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(**kwargs) + + +class RVLDummyInputsBuilder(LlavaDummyInputsBuilder[RVLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + image_token = "<image>" + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = ( + self.info.get_image_size_with_most_features()) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + } + + +class RVLMultiModalProjector(nn.Module): + + def __init__(self, config): + super().__init__() + self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size, + eps=1e-06) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, + config.text_config.hidden_size, + bias=True, + ) + 
self.act = GELUActivation() + self.linear_2 = nn.Linear( + config.text_config.hidden_size, + config.text_config.hidden_size, + bias=True, + ) + + def forward(self, image_feature: torch.Tensor) -> torch.Tensor: + image_feature = self.pre_norm(image_feature) + hidden_states = self.linear_1(image_feature) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + LlavaNextMultiModalProcessor, + info=RVLProcessingInfo, + dummy_inputs=RVLDummyInputsBuilder, +) +class RForConditionalGeneration(LlavaOnevisionForConditionalGeneration): + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # mapping for new names in checkpoint saved after transformers + # v4.52 + "model.language_model.": "language_model.model.", + "model.vision_tower.": "vision_tower.", + "model.multi_modal_projector.": "multi_modal_projector.", + "model.image_newline": "image_newline", + "lm_head.": "language_model.lm_head.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix) + config = vllm_config.model_config.hf_config + self.multi_modal_projector = RVLMultiModalProjector(config) diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py new file mode 100644 index 0000000000000..34a87a6a69a39 --- /dev/null +++ b/vllm/model_executor/models/seed_oss.py @@ -0,0 +1,487 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Seed team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. 
It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only SeedOss model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig as SeedOssConfig + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata 
+from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class SeedOssMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class SeedOssAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + self.head_dim = head_dim + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor 
parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=self.rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + attn_type=attn_type, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class SeedOssDecoderLayer(nn.Module): + + def __init__( + self, + config: SeedOssConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + + 
# By default, SeedOss uses causal attention as it is a + # decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = SeedOssAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + self.mlp = SeedOssMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class SeedOssModel(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + 
decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + # TODO (@robertgshaw2): see if this can be moved out + if (cache_config.sliding_window is not None + and hasattr(config, "max_window_layers")): + assert config.max_window_layers == config.num_hidden_layers, ( + "Sliding window for some but all layers is not supported. " + "This model uses sliding window but `max_window_layers` = {} " + "is less than `num_hidden_layers` = {}. Please open an issue " + "to discuss this feature.".format( + config.max_window_layers, + config.num_hidden_layers, + )) + + self.config = config + self.quant_config = quant_config + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + # Use the provided decoder layer type or default to SeedDecoderLayer + decoder_layer_type = decoder_layer_type or SeedOssDecoderLayer + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: decoder_layer_type(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + 
intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = SeedOssModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + 
positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py new file mode 100644 index 0000000000000..c6244fb3b3e6a --- /dev/null +++ b/vllm/model_executor/models/siglip2navit.py @@ -0,0 +1,688 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Implementation of SiglipVisionModel intended to be only used +within a vision language model.""" + +from collections.abc import Iterable +from typing import Optional + +import torch +from einops import rearrange, repeat +from torch import nn +from torch.nn import functional as F +from transformers import Siglip2VisionConfig +from transformers.configuration_utils import PretrainedConfig + +from vllm.config import QuantizationConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.platforms import _Backend + +from .vision import 
get_vit_attn_backend + + +class VisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta + **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Siglip2VisionEmbeddings(nn.Module): + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.patch_size = config.patch_size + self.image_size = config.image_size + self.num_patches = config.num_patches + self.preserve_original_pe = config.preserve_original_pe + self.hidden_stride = config.hidden_stride + + # siglip2 naflex + if self.num_patches > 0: + self.patch_embedding = ReplicatedLinear( + input_size=config.num_channels * self.patch_size * + self.patch_size, + output_size=self.embed_dim, + return_bias=False, + ) + if self.preserve_original_pe: + self.position_embedding_size = int(self.num_patches**0.5) + self.position_embedding = nn.Embedding(self.num_patches, + self.embed_dim) + + else: + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + if self.preserve_original_pe: + self.num_patches = (self.image_size // self.patch_size)**2 + self.position_embedding_size = (self.image_size // + self.patch_size) + self.position_embedding = nn.Embedding(self.num_patches, + self.embed_dim) + + def forward(self, + pixel_values: torch.FloatTensor, + grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor: + """ + Args: + pixel_values (`torch.FloatTensor`): + Pixel values of shape ( + num_patches, + num_channels * temporal_patch_size * patch_size * patch_size + ) + 
grid_thws: (`torch.LongTensor`): + grid shape (num_patches, 3) + """ + + # Apply patch embeddings to already patchified pixel values + target_dtype = self.patch_embedding.weight.dtype + if isinstance(self.patch_embedding, LinearBase): + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype)) + elif isinstance(self.patch_embedding, nn.Conv2d): + pixel_values = pixel_values.view( + -1, self.config.num_channels * self.config.temporal_patch_size, + self.patch_size, self.patch_size) + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype)) + patch_embeds = patch_embeds.reshape(-1, self.embed_dim) + + if self.preserve_original_pe: + assert grid_thws is not None + pos_embed_new = torch.zeros_like(patch_embeds) + positional_embeddings = self.position_embedding.weight.reshape( + self.position_embedding_size, self.position_embedding_size, + -1).unsqueeze(0).permute(0, 3, 1, 2) + cnt = 0 + for t, h, w in grid_thws: + volume = t * h * w + pe = F.interpolate(positional_embeddings, + size=(h, w), + mode='bicubic', + align_corners=False) + pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1) + pe = pe[0].repeat(t, 1) + pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride, + w // self.hidden_stride, self.hidden_stride, + -1) + pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(volume, -1) + pos_embed_new[cnt:cnt + volume] = pe + cnt += volume + patch_embeds = patch_embeds + pos_embed_new + + return patch_embeds + + +# copy from flash_attn/layers/rotary.py +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), + "... d two -> ... 
(d two)", + two=2) + + +def apply_rotary_emb_torch(x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat( + sin, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + return torch.cat( + [ + x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + ], + dim=-1, + ) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_flash_attn_backend: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + cos = cos.chunk(2, dim=-1)[0].contiguous() + sin = sin.chunk(2, dim=-1)[0].contiguous() + if is_flash_attn_backend: + from flash_attn.layers.rotary import apply_rotary_emb + apply_rotary_emb_func = apply_rotary_emb + else: + apply_rotary_emb_func = apply_rotary_emb_torch + q_embed = apply_rotary_emb_func(q.float(), cos.float(), + sin.float()).type_as(q) + k_embed = apply_rotary_emb_func(k.float(), cos.float(), + sin.float()).type_as(k) + return q_embed, k_embed + + +class Siglip2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = 
config.attention_dropout + self.is_causal = False + + # TODO(Isotr0py): Enable data parallel after we support + # disabling TP on parallel linear layer + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + self.tp_size = (1 if use_data_parallel else + get_tensor_model_parallel_world_size()) + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + self.use_rope = config.use_rope + + # Detect attention implementation. + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + if self.attn_backend not in { + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, + _Backend.ROCM_AITER_FA + }: + self.attn_backend = _Backend.TORCH_SDPA + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + } + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, + torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + seq_length, embed_dim = hidden_states.shape + + qkv_states, _ = self.qkv_proj(hidden_states) + queries, keys, values = qkv_states.chunk(3, dim=-1) + + queries = queries.view(seq_length, self.num_heads_per_partition, + self.head_dim) + keys = keys.view(seq_length, self.num_heads_per_partition, + self.head_dim) + values = values.view(seq_length, self.num_heads_per_partition, + self.head_dim) + + if self.use_rope: + cos, sin = position_embeddings + queries, keys = apply_rotary_pos_emb(queries.unsqueeze(0), + keys.unsqueeze(0), cos, sin, + self.is_flash_attn_backend) + queries = queries.squeeze(0) + keys = keys.squeeze(0) + + max_seqlen = (cu_seqlens[1:] - 
cu_seqlens[:-1]).max().item() + if self.is_flash_attn_backend: + if self.attn_backend == _Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + else: + from flash_attn import flash_attn_varlen_func + attn_output = flash_attn_varlen_func( + queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, + max_seqlen).reshape(seq_length, -1) + elif self.attn_backend == _Backend.TORCH_SDPA: + # Execute attention entry by entry for speed & less VRAM. + batch_size = cu_seqlens.shape[0] - 1 + outputs = [] + cu = cu_seqlens.tolist() + for i in range(batch_size): + start_idx = cu[i] + end_idx = cu[i + 1] + + # Each sequence is processed independently. + q_i = queries[start_idx:end_idx].unsqueeze(0) + k_i = keys[start_idx:end_idx].unsqueeze(0) + v_i = values[start_idx:end_idx].unsqueeze(0) + + # (1, seq_len, num_heads, head_dim) -> + # (1, num_heads, seq_len, head_dim) + q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)] + + output_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim) + output_i = output_i.transpose(1, 2).reshape( + end_idx - start_idx, -1) + outputs.append(output_i) + + attn_output = torch.cat(outputs, dim=0) + attn_output, _ = self.out_proj(attn_output) + return attn_output + + +class Siglip2MLP(nn.Module): + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + # TODO(Isotr0py): Enable data parallel after we support + # disabling TP on parallel linear layer + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def 
forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class Siglip2EncoderLayer(nn.Module): + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.self_attn = Siglip2Attention(config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + use_data_parallel=use_data_parallel) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = Siglip2MLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel) + + def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all + attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Siglip2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` + self attention layers. Each layer is a [`Siglip2EncoderLayer`]. 
+ + Args: + config: PretrainedConfig + """ + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + Siglip2EncoderLayer(config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{idx}", + use_data_parallel=use_data_parallel) + for idx in range(config.num_hidden_layers) + ]) + + self.rotary_pos_emb = VisionRotaryEmbedding( + config.hidden_size // config.num_attention_heads // 2) + self.patch_size = config.patch_size + self.hidden_stride = config.hidden_stride + self.window_size = config.window_size + self.spatial_merge_unit = config.hidden_stride * config.hidden_stride + if config.fullatt_block_indexes is None: + self.fullatt_block_indexes = None + else: + self.fullatt_block_indexes = [ + int(i) for i in config.fullatt_block_indexes.split('|') + ] + + # copied from qwen2.5_vl + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.hidden_stride, + self.hidden_stride, + w // self.hidden_stride, + self.hidden_stride, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.hidden_stride, + self.hidden_stride, + w // self.hidden_stride, + self.hidden_stride, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + 
window_index_id = 0 + # patch (after merge) number in each window + vit_merger_window_size = (self.window_size // self.hidden_stride // + self.patch_size) + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.hidden_stride, # number of patch after merge + grid_w // self.hidden_stride, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum( + 0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward( + self, + inputs_embeds: torch.Tensor, + grid_thws: torch.Tensor, + ) -> torch.Tensor: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape + `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. 
This is useful if + you want more control over how to convert `input_ids` indices + into associated vectors than the model's internal embedding + lookup matrix. + grid_thws (`torch.LongTensor`): + grid shape (num_patches, 3) + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of + a plain tuple. + """ + rotary_pos_emb = self.rot_pos_emb(grid_thws) + window_index, cu_window_seqlens = self.get_window_index(grid_thws) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=inputs_embeds.device, + dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = inputs_embeds.size() + inputs_embeds = inputs_embeds.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + inputs_embeds = inputs_embeds[window_index, :, :] + inputs_embeds = inputs_embeds.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave( + grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0] + ).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have + # same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 + # for more information + dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + reverse_indices = 
torch.argsort(window_index) + + hidden_states = inputs_embeds + for index, block in enumerate(self.layers): + if (not self.fullatt_block_indexes + or index in self.fullatt_block_indexes): + cu_seqlens_tmp = cu_seqlens + else: + cu_seqlens_tmp = cu_window_seqlens + hidden_states = block(hidden_states, cu_seqlens_tmp, + position_embeddings) + + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1) + + return hidden_states + + +class Siglip2VisionTransformer(nn.Module): + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = Siglip2VisionEmbeddings(config) + self.encoder = Siglip2Encoder(config, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + use_data_parallel=use_data_parallel) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + pixel_values: torch.FloatTensor, + grid_thws: torch.LongTensor, + ) -> torch.Tensor: + r""" + spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): + Tensor containing the spatial dimensions (height, width) + of the input images. 
+ """ + hidden_states = self.embeddings(pixel_values, grid_thws) + + last_hidden_state = self.encoder(hidden_states, grid_thws) + last_hidden_state = self.post_layernorm(last_hidden_state) + + return last_hidden_state + + +class Siglip2NavitModel(torch.nn.Module): + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + + self.vision_model = Siglip2VisionTransformer( + config, + quant_config=quant_config, + prefix=f"{prefix}.vision_model", + use_data_parallel=use_data_parallel) + + def forward( + self, + pixel_values: torch.FloatTensor, + grid_thws: torch.LongTensor, + ) -> torch.Tensor: + return self.vision_model( + pixel_values=pixel_values, + grid_thws=grid_thws, + ) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index c76aabcd27ccb..9857ccdcbe2d4 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -8,7 +8,7 @@ # Licensed under The MIT License [see LICENSE for details] # 
-------------------------------------------------------- from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn @@ -26,7 +26,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -35,6 +35,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, @@ -48,27 +49,42 @@ IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) -class SkyworkR1VImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values_flat: torch.Tensor +class SkyworkR1VImagePixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` + Dimensions: + - bnp: Batch size * number of images * (1 + num_patches) + - c: Number of channels (3) + - h: Height + - w: Width + - bn: Batch size * number of images """ + type: Literal["pixel_values"] = "pixel_values" - num_patches: torch.Tensor - """Shape: `(batch_size * num_images)`""" + pixel_values_flat: Annotated[ + torch.Tensor, + TensorShape("bnp", 3, "h", "w"), + ] + + num_patches: Annotated[ + 
torch.Tensor, + TensorShape("bn"), + ] -class SkyworkR1VImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: Union[torch.Tensor, list[torch.Tensor]] - """ - A tensor of shape `(num_images, total_image_feature_size, hidden_size)` - or a list of tensors of shape `(total_image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. +class SkyworkR1VImageEmbeddingInputs(TensorSchema): """ + Dimensions: + - ni: Number of images + - ifs: Image feature size + - hs: Hidden size (must match the hidden size of language model + backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("ni", "ifs", "hs"), + ] SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs, @@ -552,18 +568,19 @@ class SkyworkR1VMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if "image_num_patches" in out_mm_kwargs: - image_num_patches = out_mm_kwargs["image_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "image_num_patches" in out_mm_data: + image_num_patches = out_mm_data["image_num_patches"] assert isinstance(image_num_patches, torch.Tensor) image_num_patches = image_num_patches.tolist() - elif "image_embeds" in out_mm_kwargs: + elif "image_embeds" in out_mm_data: # TODO: Use image size information in dictionary embedding inputs # to compute num_patches (similar to Qwen2-VL) - image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + image_num_patches = [None] * len(out_mm_data["image_embeds"]) else: image_num_patches = [] @@ -730,26 +747,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): vit_embeds = self.mlp1(vit_embeds) return vit_embeds - def _validate_pixel_values(self, data: 
torch.Tensor) -> torch.Tensor: - - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - "The expected shape of pixel values per image per batch " - f" per patch is {expected_expr}. " - f"You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) @@ -787,10 +784,12 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): return SkyworkR1VImagePixelInputs( type="pixel_values", - pixel_values_flat=self._validate_pixel_values( - pixel_values_flat), + pixel_values_flat=pixel_values_flat, num_patches=image_num_patches, - ) + resolve_bindings={ + "h": self.config.vision_config.image_size, + "w": self.config.vision_config.image_size, + }) raise AssertionError("This line should be unreachable.") diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index f1f38c01b7848..f379d2c15fb6c 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -28,7 +28,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -520,20 +520,18 @@ class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo] self, mm_items: MultiModalDataItems, 
hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_placeholder_token_id = hf_processor.image_token_id - batch_num_patches = out_mm_kwargs["num_patches"].tolist() def get_replacement_step1o(item_idx: int): - img_out = out_mm_kwargs.get_item("image", item_idx) - num_patches = batch_num_patches[item_idx] + out_item = out_mm_kwargs["image"][item_idx] + num_patches = int(out_item["num_patches"].data) if num_patches > 0: - patch_newline_mask = img_out["patch_newline_mask"].data.tolist( - ) + patch_newline_mask = out_item["patch_newline_mask"].data image_repl_ids = hf_processor._get_image_repl_features( - 1, num_patches, patch_newline_mask)[1] + 1, num_patches, patch_newline_mask.tolist())[1] else: image_repl_ids = hf_processor._get_image_repl_features( 1, 0, None)[1] @@ -869,6 +867,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, "lm_head.": "language_model.lm_head.", }) + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -884,8 +884,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.config = config self.multimodal_config = multimodal_config - self.use_data_parallel = (vllm_config.parallel_config. 
- enable_multimodal_encoder_data_parallel) + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" if multimodal_config.get_limit_per_prompt("image"): self.vision_model = Step3VisionTransformer( diff --git a/vllm/model_executor/models/swin.py b/vllm/model_executor/models/swin.py new file mode 100644 index 0000000000000..30b441f5b4df0 --- /dev/null +++ b/vllm/model_executor/models/swin.py @@ -0,0 +1,475 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import SwinConfig +from transformers.models.swin.modeling_swin import SwinEmbeddings +from transformers.models.swin.modeling_swin import SwinLayer as HFSwinLayer +from transformers.models.swin.modeling_swin import SwinPatchMerging +from transformers.pytorch_utils import meshgrid + +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + +class SwinSelfAttention(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + num_heads: int, + window_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of " + f"attention heads ({num_heads})") + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = (window_size if isinstance(window_size, Iterable) + else (window_size, window_size)) + self.scale = self.attention_head_size**-0.5 + + 
self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), + num_heads)) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, + None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + + self.relative_position_index = nn.Parameter(relative_position_index, + requires_grad=False) + + self.qkv = QKVParallelLinear( + hidden_size=dim, + head_size=self.attention_head_size, + total_num_heads=self.num_attention_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def _get_rel_pos_bias(self) -> torch.Tensor: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1) + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() + return relative_position_bias.unsqueeze(0) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor, ...]: + batch_size, dim, num_channels = 
hidden_states.shape + + qkv_output, _ = self.qkv(hidden_states) + query_layer, key_layer, value_layer = qkv_output.chunk(3, dim=-1) + + key_layer = self.transpose_for_scores(key_layer) + value_layer = self.transpose_for_scores(value_layer) + query_layer = self.transpose_for_scores(query_layer) + + attention_scores = self._get_rel_pos_bias() + if attention_mask is not None: + mask_shape = attention_mask.shape[0] + attention_mask_expanded = attention_mask.view( + 1, mask_shape, 1, dim, + dim).expand(batch_size // mask_shape, mask_shape, + self.num_attention_heads, dim, dim) + attention_scores = attention_scores + \ + attention_mask_expanded.unsqueeze( + 1).unsqueeze(0) + attention_scores = attention_scores.view(-1, + self.num_attention_heads, + dim, dim) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_scores, + dropout_p=0., + ) + attention_probs = None + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + return outputs + + +class SwinSelfOutput(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.dense = RowParallelLinear( + input_size=dim, + output_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + + return hidden_states + + +class SwinAttention(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + num_heads: int, + window_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + 
super().__init__() + self.self = SwinSelfAttention(config, + dim, + num_heads, + window_size, + quant_config=quant_config, + prefix=f"{prefix}.self") + self.output = SwinSelfOutput(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.output") + self.pruned_heads = set() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, + output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, ) + self_outputs[1:] + return outputs + + +class SwinIntermediate(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.dense = ColumnParallelLinear(dim, + int(config.mlp_ratio * dim), + quant_config=quant_config, + prefix=f"{prefix}.dense") + self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class SwinOutput(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.dense = RowParallelLinear(int(config.mlp_ratio * dim), + dim, + quant_config=quant_config, + prefix=f"{prefix}.dense") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + return hidden_states + + +class SwinLayer(HFSwinLayer): + + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: int, + num_heads: int, + drop_path_rate: float = 0.0, + shift_size: int = 0, + quant_config: 
Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path_rate, + shift_size=shift_size, + ) + + self.attention = SwinAttention(config, + dim, + num_heads, + window_size=self.window_size, + quant_config=quant_config, + prefix=f"{prefix}.attention") + self.intermediate = SwinIntermediate(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.intermediate") + self.output = SwinOutput(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.output") + + +class SwinStage(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: int, + depth: int, + num_heads: int, + drop_path: list[float], + downsample: Optional[SwinPatchMerging] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList([ + SwinLayer(config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path[layer_idx], + shift_size=0 if + (layer_idx % 2 == 0) else config.window_size // 2, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, + dim=dim, + norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, 
input_dimensions, + layer_head_mask, output_attentions, + always_partition) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + + 1) // 2 + output_dimensions = (height, width, height_downsampled, + width_downsampled) + hidden_states = self.downsample(hidden_states_before_downsampling, + input_dimensions) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, hidden_states_before_downsampling, + output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +class SwinEncoder(nn.Module): + + def __init__( + self, + config: SwinConfig, + grid_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.num_layers = len(config.depths) + self.config = config + dpr = [ + x.item() for x in torch.linspace( + 0, config.drop_path_rate, sum(config.depths), device="cpu") + ] + self.layers = nn.ModuleList([ + SwinStage(config=config, + dim=int(config.embed_dim * 2**layer_idx), + input_resolution=(grid_size[0] // (2**layer_idx), + grid_size[1] // (2**layer_idx)), + depth=config.depths[layer_idx], + num_heads=config.num_heads[layer_idx], + drop_path=dpr[sum(config.depths[:layer_idx] + ):sum(config.depths[:layer_idx + 1])], + downsample=SwinPatchMerging if + (layer_idx < self.num_layers - 1) else None, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(self.num_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + 
layer_outputs = layer_module(hidden_states, input_dimensions, + layer_head_mask, output_attentions, + always_partition) + + hidden_states = layer_outputs[0] + output_dimensions = layer_outputs[2] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + + return hidden_states + + +class SwinModel(nn.Module): + config_class: SwinConfig + + def __init__( + self, + config: SwinConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2**(self.num_layers - 1)) + + self.embeddings = SwinEmbeddings(config) + self.encoder = SwinEncoder(config, + self.embeddings.patch_grid, + quant_config=quant_config, + prefix=f"{prefix}.encoder") + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + ) -> tuple[torch.Tensor]: + embedding_output, input_dimensions = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + ) + + return encoder_outputs + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv", "query", "q"), + ("qkv", "key", "k"), + ("qkv", "value", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + 
loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index c8709d866b1e7..c66867315e553 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -3,7 +3,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, +from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, Union, cast) import torch @@ -25,15 +25,17 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.llava import LlavaDummyInputsBuilder from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.cache import BaseMultiModalProcessorCache +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate) + BaseProcessingInfo, PromptReplacement, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.jsontree import json_map_leaves +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -43,14 +45,28 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .vision import VisionEncoderInfo, get_vision_encoder_info -class TarsierImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor +class TarsierImagePixelInputs(TensorSchema): + """ + Dimensions: + - 
bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width + """ + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] -class TarsierImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor +class TarsierImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match the hidden size of language model + backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] TarsierImageInputs = Union[TarsierImagePixelInputs, @@ -275,7 +291,7 @@ class TarsierMultiModalProcessor(BaseMultiModalProcessor[_I_Tarsier]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index # The <IMAGE> token ID @@ -317,7 +333,7 @@ def _build_tarsier_hf_processor( info: _I_Tarsier, dummy_inputs: BaseDummyInputsBuilder[_I_Tarsier], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, TarsierProcessingInfo): return TarsierMultiModalProcessor( @@ -432,18 +448,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) # Assuming 3 channels - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. 
" - f"You supplied {tuple(data.shape)}.") - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[TarsierImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -459,8 +463,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, return TarsierImagePixelInputs( type="pixel_values", - pixel_values=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), + pixel_values=flatten_bn(pixel_values, concat=True), ) if image_embeds is not None: diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 4ec2b683fc33a..fc242d1adafd0 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -41,7 +41,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputs, PlaceholderRange) from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems @@ -237,7 +237,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ): """ Given the original multi-modal items for this modality @@ -310,7 +310,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, ) -> MultiModalInputs: """ Process multi-modal inputs to be used in 
vLLM. @@ -372,7 +371,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): mm_tokens_per_modality["num_image_patches"] ) if "num_image_patches" in mm_tokens_per_modality else None processed_data['num_image_patches'] = num_image_patches - mm_kwargs = MultiModalKwargs.from_hf_inputs( + mm_kwargs = MultiModalKwargsItems.from_hf_inputs( processed_data, self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs, num_image_patches), @@ -694,10 +693,28 @@ class TransformersForCausalLM(TransformersBase): return logits +def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor: + """Flatten until a list of tensors can be concatenated then do concat""" + + def _can_concat(x: list[torch.Tensor]): + return len(set(map(lambda _x: _x.shape[1:], x))) == 1 + + if _can_concat(x): + return torch.concat(x) + return flatten_and_concat(flatten_bn(x)) + + @MULTIMODAL_REGISTRY.register_processor( MultiModalProcessor, info=MultiModalProcessingInfo, dummy_inputs=MultiModalDummyInputsBuilder) +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) # set `positions` to last dim to support Qwen-mrope class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): # Backwards compatibility for prev released models. State dicts back then # had different formats and cannot be loaded with `AutoModel` mapping as is @@ -766,8 +783,7 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): if isinstance(pixel_values, torch.Tensor): pixel_values = flatten_bn(pixel_values).to(self.dtype) elif is_list_of(pixel_values, torch.Tensor): - pixel_values = flatten_bn(flatten_bn(pixel_values), - concat=True).to(self.dtype) + pixel_values = flatten_and_concat(pixel_values).to(self.dtype) else: raise ValueError( f"Unsupported pixel_values type {type(pixel_values)}. 
" diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index bef34c1be49fe..f91c4ddb6e834 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -23,7 +23,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -194,7 +194,7 @@ class UltravoxMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -203,7 +203,8 @@ class UltravoxMultiModalProcessor( # Each audio can be split into multiple chunks. # chunks_start_idx[i] indicates the start index of the chunks # belonging to the i-th audio. 
- num_chunks = out_mm_kwargs.get("audio_num_chunks", torch.zeros(0)) + out_mm_data = out_mm_kwargs.get_data() + num_chunks = out_mm_data.get("audio_num_chunks", torch.zeros(0)) chunks_start_idx: torch.Tensor = torch.cumsum(num_chunks, dim=0, dtype=torch.int32) @@ -213,7 +214,7 @@ class UltravoxMultiModalProcessor( def get_replacement_ultravox(item_idx: int): start = chunks_start_idx[item_idx] end = chunks_start_idx[item_idx + 1] - audio_token_len = out_mm_kwargs["audio_token_len"][start:end].sum() + audio_token_len = out_mm_data["audio_token_len"][start:end].sum() return [replacement_id] * int(audio_token_len) # type: ignore return [ diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 6c27fedc61b17..11e098f1d7bdb 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -508,7 +508,9 @@ def merge_multimodal_embeddings( """ if isinstance(placeholder_token_id, list): placeholder_token_id = torch.tensor(placeholder_token_id, - device=input_ids.device) + pin_memory=True).to( + device=input_ids.device, + non_blocking=True) return _merge_multimodal_embeddings( inputs_embeds, torch.isin(input_ids, placeholder_token_id), diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index 6b06c0ac6683f..77f11a691e080 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -31,11 +31,12 @@ from vllm.model_executor.models.whisper import WhisperEncoder from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, MultiModalHashes, + 
BaseProcessingInfo, + MultiModalProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -259,7 +260,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -287,20 +288,16 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: - prompt_ids, mm_kwargs, mm_hashes, _ = super( - )._cached_apply_hf_processor( + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: + prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, ) # NOTE: The tokens are already inserted by the chat template - return prompt_ids, mm_kwargs, mm_hashes, True + return prompt_ids, mm_info, True def _get_data_parser(self) -> MultiModalDataParser: sampling_rate = self.info.get_hf_processor().sampling_rate diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index ca02ecd828ba3..16bbe2f2010a1 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -33,7 +33,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors from vllm.multimodal.inputs 
import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, @@ -728,7 +728,7 @@ class WhisperMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: num_tokens = self.info.get_num_audio_tokens() return [ diff --git a/vllm/model_executor/pooling_metadata.py b/vllm/model_executor/pooling_metadata.py index e6f1ca61dd291..3209879193453 100644 --- a/vllm/model_executor/pooling_metadata.py +++ b/vllm/model_executor/pooling_metadata.py @@ -2,12 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Any +from typing import Any, Optional import torch from vllm.pooling_params import PoolingParams from vllm.utils import is_pin_memory_available +from vllm.v1.pool.metadata import PoolingCursor, build_pooling_cursor class PoolingMetadata: @@ -23,14 +24,15 @@ class PoolingMetadata: """ def __init__( - self, - seq_groups: list[tuple[list[int], PoolingParams]], - seq_data: dict[int, Any], # Specific data related to sequences - prompt_lens: list[int], - ) -> None: + self, + seq_groups: list[tuple[list[int], PoolingParams]], + seq_data: dict[int, Any], # Specific data related to sequences + prompt_lens: list[int], + pooling_cursor: Optional[PoolingCursor] = None) -> None: self.seq_groups = seq_groups self.seq_data = seq_data self.prompt_lens = prompt_lens + self.pooling_cursor: Optional[PoolingCursor] = pooling_cursor def __repr__(self) -> str: return ("PoolingMetadata(" @@ -43,8 +45,17 @@ class PoolingMetadata: seq_groups=self.seq_groups[indices], seq_data=dict(list(self.seq_data.items())[indices]), prompt_lens=self.prompt_lens[indices], + pooling_cursor=None + 
if self.pooling_cursor is None else self.pooling_cursor[indices], ) + def build_pooling_cursor(self, num_scheduled_tokens: list[int], + device: torch.device): + prompt_lens = torch.tensor(self.prompt_lens, device="cpu") + self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens, + prompt_lens, + device=device) + @dataclass class PoolingTensors: diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 2ef9f1ccc02be..69eed22741446 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -4,7 +4,8 @@ from .base import MultiModalPlaceholderMap from .hasher import MultiModalHashDict, MultiModalHasher from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins, MultiModalDataDict, MultiModalKwargs, - MultiModalPlaceholderDict, NestedTensors) + MultiModalKwargsItems, MultiModalPlaceholderDict, + NestedTensors) from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() @@ -25,6 +26,7 @@ __all__ = [ "MultiModalHashDict", "MultiModalHasher", "MultiModalKwargs", + "MultiModalKwargsItems", "MultiModalPlaceholderDict", "MultiModalPlaceholderMap", "NestedTensors", diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 7188ed14c5735..ef8f1b2e17b47 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -99,7 +99,7 @@ class MultiModalPlaceholderMap: seq_mm_placeholders = seq_group.multi_modal_placeholders if not seq_mm_data or not seq_mm_placeholders: - return MultiModalKwargs({}), {} + return MultiModalKwargs(), {} placeholder_maps = dict[str, MultiModalPlaceholderMap]() diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py index 6074a4d54f223..0e81cb6d4d190 100644 --- a/vllm/multimodal/cache.py +++ b/vllm/multimodal/cache.py @@ -1,35 +1,82 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import sys -from collections.abc import Mapping -from dataclasses import dataclass -from typing import 
TypeVar, Union +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union import torch +from typing_extensions import TypeAlias, override from vllm.logger import init_logger from vllm.utils import GiB_bytes, LRUCache from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves -from .inputs import MultiModalKwargs, MultiModalKwargsItem, NestedTensors +from .inputs import (MultiModalFieldElem, MultiModalKwargs, + MultiModalKwargsItem, MultiModalKwargsItems, + NestedTensors) + +if TYPE_CHECKING: + from vllm.config import ModelConfig, VllmConfig + + from .processing import ResolvedPromptUpdate + from .registry import MultiModalRegistry logger = init_logger(__name__) -@dataclass -class MultiModalCacheItemMetadata: - size: int +class MultiModalProcessorCacheItem: + """ + The data to store inside `MultiModalProcessorOnlyCache`. - @classmethod - def wraps(cls, value: "MultiModalCacheValue"): - return cls(size=MultiModalCache.get_item_size(value)) + Args: + item: The processed tensor data corresponding to a multi-modal item. + prompt_updates: The prompt updates corresponding to `item`. + """ + + def __init__( + self, + item: MultiModalKwargsItem, + prompt_updates: Sequence["ResolvedPromptUpdate"], + ) -> None: + super().__init__() + + self.item = item + self.prompt_updates = prompt_updates + + +class MultiModalProcessorCacheItemMetadata: + """ + The metadata to store inside `MultiModalProcessorSenderCache`. + + Args: + item: The processed tensor data corresponding to a multi-modal item. + Since P1 already stores the tensor data, we only store its size + metadata in P0 to reduce memory usage. The size metadata is still + needed to keep the same cache eviction policy as P0. + prompt_updates: The prompt updates corresponding to `item`. + This needs to stay on P0 because for some models, they are + dependent on the processed tensor data (cached on P1). 
+ """ + + def __init__( + self, + item: MultiModalKwargsItem, + prompt_updates: Sequence["ResolvedPromptUpdate"], + ) -> None: + super().__init__() + + self.item_size = MultiModalCache.get_item_size(item) + self.prompt_updates = prompt_updates MultiModalCacheValue = Union[ - MultiModalKwargs, + MultiModalProcessorCacheItem, + MultiModalProcessorCacheItemMetadata, + MultiModalKwargsItems, MultiModalKwargsItem, + MultiModalKwargs, Mapping[str, NestedTensors], - MultiModalCacheItemMetadata, ] _V = TypeVar("_V", bound=MultiModalCacheValue) @@ -44,22 +91,26 @@ class MultiModalCache: *, debug: bool = False, ) -> int: - # MultiModalKwargs is not a subclass of dict - if isinstance(leaf, MultiModalKwargs): - return cls.get_item_size(leaf.data, debug=debug) + if isinstance(leaf, MultiModalProcessorCacheItem): + return cls.get_leaf_size(leaf.item) + if isinstance(leaf, MultiModalProcessorCacheItemMetadata): + return leaf.item_size - # MultiModalKwargsItem is not a subclass of dict + # These are not subclasses of dict + if isinstance(leaf, MultiModalKwargsItems): + return cls.get_item_size(leaf.data) # type: ignore if isinstance(leaf, MultiModalKwargsItem): - leaf_data = {k: v.data for k, v in leaf.items()} - return cls.get_item_size(leaf_data, debug=debug) + return cls.get_item_size(leaf.data) # type: ignore + if isinstance(leaf, MultiModalKwargs): + return cls.get_item_size(leaf.data) # type: ignore + + if isinstance(leaf, MultiModalFieldElem): + return cls.get_item_size(leaf.data) # type: ignore # sys.getsizeof doesn't work for tensors if isinstance(leaf, torch.Tensor): return leaf.nbytes - if isinstance(leaf, MultiModalCacheItemMetadata): - return leaf.size - return sys.getsizeof(leaf) @classmethod @@ -93,3 +144,332 @@ class MultiModalCache: GiB_bytes * capacity_gb, getsizeof=lambda x: cls.get_item_size(x, debug=debug), ) + + +_I = TypeVar("_I", contravariant=True) +_O = TypeVar("_O", covariant=True) + + +class BaseMultiModalCache(ABC, Generic[_I, _O]): + """ + Abstract 
base class to read/write multi-modal items from cache. + + The idea of multi-modal caching is based on having a client and server + where the client executes in the frontend process (=P0) and + the server in the core process (=P1). The data flow is as follows: + + ``` + is_cached() x N get_and_update() + P0: From API -----------------> -----------------> To P1 + + get_and_update() + P1: From P0 -----------------> To model + ``` + + `is_cached()` can be called any number of times in P0. However, + `get_and_update()` must be called in P0 and P1 one after another + so that their cache eviction order remains the same. + + This ensures that the keys in P0 and P1 caches are mirrored, + allowing us to determine whether a key is cached in P1 by looking + up the P0 cache, without having to communicate with P1. + """ + + @abstractmethod + def get_and_update_item( + self, + mm_item: _I, + mm_hash: str, + ) -> _O: + """ + Possibly update a multi-modal item based on whether it is + in the underlying cache. + + This update is done out-of-place and updates the cache eviction order. + + Args: + mm_item: The multi-modal item to update. + mm_hash: The hash of `mm_item`. + + Returns: + The update multi-modal item. + """ + raise NotImplementedError + + def get_and_update( + self, + mm_items: Sequence[_I], + mm_hashes: list[str], + ) -> list[_O]: + """ + Possibly update a sequence of multi-modal items based on whether they + are in the underlying cache. + + This update is done out-of-place and updates the cache eviction order. + + Args: + mm_items: The multi-modal items to update. + mm_hashes: The hash of each item in `mm_items`. + + Returns: + A new list of updated multi-modal items. 
+ """ + assert len(mm_items) == len(mm_hashes) + + return [ + self.get_and_update_item(mm_item, mm_hash) + for mm_item, mm_hash in zip(mm_items, mm_hashes) + ] + + @abstractmethod + def clear_cache(self) -> None: + """Clear the underlying cache.""" + raise NotImplementedError + + +MultiModalProcessorCacheInItem: TypeAlias = \ + Optional[tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]]] + + +MultiModalProcessorCacheOutItem: TypeAlias = \ + tuple[Optional[MultiModalKwargsItem], Sequence["ResolvedPromptUpdate"]] + + +class BaseMultiModalProcessorCache( + BaseMultiModalCache[MultiModalProcessorCacheInItem, + MultiModalProcessorCacheOutItem]): + """The required interface for caches on P0.""" + + @abstractmethod + def is_cached_item(self, mm_hash: str) -> bool: + """ + Check whether a multi-modal item is + in the underlying cache. + + This **DOES NOT** update the cache eviction order. + + Args: + mm_hash: The hash of the item to check. + + Returns: + `True` if the item is cached, otherwise `False`. + """ + raise NotImplementedError + + def is_cached(self, mm_hashes: list[str]) -> list[bool]: + """ + Check whether a sequence of multi-modal items are + in the underlying cache. + + This **DOES NOT** update the cache eviction order. + + Args: + mm_hashes: The hash of each item to check. + + Returns: + For each item, `True` if the item is cached, otherwise `False`. + """ + return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes] + + +class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache): + """ + The cache which is used on P0 when IPC caching is disabled. + + How to update each item: + + - If the item is in the cache, replace the input with the cached item. + - If the item is not in the cache, store that item (which includes + tensor data and metadata) into the cache, and return the input. 
+ """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalProcessorCacheItem, + ) + + @override + def is_cached_item(self, mm_hash: str) -> bool: + return mm_hash in self._cache + + @override + def get_and_update_item( + self, + mm_item: MultiModalProcessorCacheInItem, + mm_hash: str, + ) -> MultiModalProcessorCacheOutItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return cached_item.item, cached_item.prompt_updates + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item) + + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + +class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache): + """ + The cache which is used on P0 when IPC caching is enabled. + + How to update each item: + + - If the item is already in the cache, clear the input to avoid + unnecessary IPC. + + - If the item is not in the cache, store the metadata of that item so + that the eviction policy remains the same as the cache on P1, + and return the input. + By only storing the metadata, we avoid keeping the data itself in + memory inside P0. 
+ """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalProcessorCacheItemMetadata, + ) + + @override + def is_cached_item(self, mm_hash: str) -> bool: + return mm_hash in self._cache + + @override + def get_and_update_item( + self, + mm_item: MultiModalProcessorCacheInItem, + mm_hash: str, + ) -> MultiModalProcessorCacheOutItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return None, cached_item.prompt_updates + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item) + + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + +def _enable_processor_cache( + model_config: "ModelConfig", + mm_registry: "MultiModalRegistry", +) -> bool: + if not mm_registry.supports_multimodal_inputs(model_config): + return False + + mm_config = model_config.get_multimodal_config() + return mm_config.mm_processor_cache_gb > 0 + + +def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool: + parallel_config = vllm_config.parallel_config + supports_ipc_cache = (parallel_config.data_parallel_size == 1 + or parallel_config.data_parallel_external_lb) + + return supports_ipc_cache + + +def processor_cache_from_config( + vllm_config: "VllmConfig", + mm_registry: "MultiModalRegistry", +) -> Optional[BaseMultiModalProcessorCache]: + """Return a `BaseMultiModalProcessorCache`, if enabled.""" + model_config = vllm_config.model_config + + if not _enable_processor_cache(model_config, mm_registry): + return None + + if not _enable_ipc_cache(vllm_config): + return MultiModalProcessorOnlyCache(model_config) + + return MultiModalProcessorSenderCache(model_config) + + +def processor_only_cache_from_config( + model_config: "ModelConfig", + mm_registry: "MultiModalRegistry", +): + 
"""Return a `MultiModalProcessorOnlyCache`, if enabled.""" + if not _enable_processor_cache(model_config, mm_registry): + return None + + return MultiModalProcessorOnlyCache(model_config) + + +class BaseMultiModalReceiverCache( + BaseMultiModalCache[Optional[MultiModalKwargsItem], + MultiModalKwargsItem]): + """The required interface for caches on P1.""" + + +class MultiModalReceiverCache(BaseMultiModalReceiverCache): + """ + The cache which is used on P1 when IPC caching is enabled. + + How to update each item: + + - If the item is in the cache, replace the input with the cached item. + - If the item is not in the cache, store that item (which includes tensor + data) into the cache, and return the input. + """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalKwargsItem, + ) + + @override + def get_and_update_item( + self, + mm_item: Optional[MultiModalKwargsItem], + mm_hash: str, + ) -> MultiModalKwargsItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return cached_item + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = mm_item + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + +def receiver_cache_from_config( + vllm_config: "VllmConfig", + mm_registry: "MultiModalRegistry", +) -> Optional[BaseMultiModalReceiverCache]: + """Return a `BaseMultiModalReceiverCache`, if enabled.""" + model_config = vllm_config.model_config + + if not _enable_processor_cache(model_config, mm_registry): + return None + + if not _enable_ipc_cache(vllm_config): + return None + + return MultiModalReceiverCache(model_config) diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py index c9ce1f0be5f88..3708dc7065ba1 100644 --- a/vllm/multimodal/hasher.py +++ b/vllm/multimodal/hasher.py @@ -43,7 
+43,25 @@ class MultiModalHasher: return cls.item_to_bytes( "image", np.asarray(convert_image_mode(obj, "RGBA"))) if isinstance(obj, torch.Tensor): - return cls.item_to_bytes("tensor", obj.numpy()) + tensor_obj: torch.Tensor = obj.cpu() + tensor_dtype = tensor_obj.dtype + tensor_shape = tensor_obj.shape + + # NumPy does not support bfloat16. + # Workaround: View the tensor as a contiguous 1D array of bytes + if tensor_dtype == torch.bfloat16: + tensor_obj = tensor_obj.contiguous() + tensor_obj = tensor_obj.view( + (tensor_obj.numel(), )).view(torch.uint8) + + return cls.item_to_bytes( + "tensor", { + "original_dtype": str(tensor_dtype), + "original_shape": tuple(tensor_shape), + "data": tensor_obj.numpy(), + }) + + return cls.item_to_bytes("tensor", tensor_obj.numpy()) if isinstance(obj, np.ndarray): # If the array is non-contiguous, we need to copy it first arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes() diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 0bbac45c121b6..2c0ebaced67ef 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -4,14 +4,14 @@ from abc import ABC, abstractmethod from collections import UserDict, defaultdict from collections.abc import Mapping, Sequence -from dataclasses import dataclass, replace +from dataclasses import dataclass from functools import partial from itertools import accumulate -from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar, - Union, cast, final) +from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, + cast, final) import numpy as np -from typing_extensions import NotRequired, TypeAlias +from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated from vllm.utils import LazyLoader, full_groupby, is_list_of from vllm.utils.jsontree import JSONTree, json_map_leaves @@ -218,7 +218,7 @@ class MultiModalFieldElem: i.e. the name of the keyword argument to be passed to the model. 
""" - data: Optional[NestedTensors] + data: NestedTensors """ The tensor data of this field in [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs], @@ -315,13 +315,8 @@ class BaseMultiModalField(ABC): if len(set(field_types)) > 1: raise ValueError(f"Cannot merge different {field_types=}") - validated_data = list[NestedTensors]() - for i, elem in enumerate(elems): - assert elem.data is not None, ( - f"Cannot merge with empty `elems[{i}]`") - validated_data.append(elem.data) - - return self._reduce_data(validated_data, pin_memory=pin_memory) + batch = [elem.data for elem in elems] + return self._reduce_data(batch, pin_memory=pin_memory) @dataclass(frozen=True) @@ -643,71 +638,49 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]. """ + @staticmethod + def dummy(modality: str): + """Convenience class for testing.""" + mm_elem = MultiModalFieldElem( + modality=modality, + key="dummy", + data=torch.empty(1), + field=MultiModalSharedField(1), + ) + return MultiModalKwargsItem.from_elems([mm_elem]) + @staticmethod def from_elems(elems: Sequence[MultiModalFieldElem]): return MultiModalKwargsItem({elem.key: elem for elem in elems}) - def __init__(self, data: Mapping[str, MultiModalFieldElem]) -> None: + def __init__(self, data: Mapping[str, MultiModalFieldElem] = {}) -> None: super().__init__(data) - modalities = {elem.modality for elem in self.data.values()} + modalities = {elem.modality for elem in self.values()} assert len(modalities) == 1, f"Found different modalities={modalities}" self._modality = next(iter(modalities)) - self._is_empty = any(elem.data is None for elem in self.values()) - @property def modality(self) -> str: return self._modality - @property - def is_empty(self) -> bool: - return self._is_empty - - def get_data(self) -> Optional[Mapping[str, NestedTensors]]: - if self._is_empty: - return None - - out_data = dict[str, NestedTensors]() - for key, elem in 
self.items(): - assert elem.data is not None, ( - f"Cannot get data of empty `elem[{key!r}]`") - out_data[key] = elem.data - - return out_data - - def require_data(self) -> Mapping[str, NestedTensors]: - if (data := self.get_data()) is None: - raise RuntimeError("Cannot get data of empty item") - - return data - - # These methods create a new item to avoid mutating cached items in place - def with_data(self, data: Mapping[str, NestedTensors]): - return MultiModalKwargsItem({ - key: replace(elem, data=data[key]) - for key, elem in self.items() - }) - - def without_data(self): - return MultiModalKwargsItem({ - key: replace(elem, data=None) - for key, elem in self.items() - }) + def get_data(self) -> dict[str, NestedTensors]: + return {key: elem.data for key, elem in self.items()} -# NOTE: UserDict is for V0 compatibility. -# V1 should access individual items via `get_item`. -class MultiModalKwargs(UserDict[str, NestedTensors]): +_I = TypeVar( + "_I", + MultiModalKwargsItem, + Optional[MultiModalKwargsItem], + default=MultiModalKwargsItem, +) + + +class MultiModalKwargsItems(UserDict[str, Sequence[_I]]): """ - A dictionary that represents the keyword arguments to - [`torch.nn.Module.forward`][]. - - The metadata `items` enables us to obtain the keyword arguments - corresponding to each data item in - [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems], via - [`get_item`][vllm.multimodal.inputs.MultiModalKwargs.get_item] and - [`get_items`][vllm.multimodal.inputs.MultiModalKwargs.get_items]. + A dictionary of + [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s + by modality. 
""" @staticmethod @@ -742,43 +715,74 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): elems = [v[item_idx] for v in elems_in_modality.values()] items.append(MultiModalKwargsItem.from_elems(elems)) - return MultiModalKwargs.from_items(items) + return MultiModalKwargsItems.from_seq(items) @staticmethod + def from_seq(items: Sequence[MultiModalKwargsItem]): + items_by_modality = full_groupby(items, key=lambda x: x.modality) + return MultiModalKwargsItems(items_by_modality) + + def __getitem__(self, modality: str) -> Sequence[_I]: + if modality not in self: + raise KeyError(f"Modality {modality!r} not found. " + f"Available modalities: {set(self.keys())}") + + return super().__getitem__(modality) # type: ignore[return-value] + + def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs": + elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) + for modality, items in self.items(): + for i, item in enumerate(items): + if item is None: + raise RuntimeError("Cannot build data from empty " + f"mm_items[{modality}][{i}]") + + for key, elem in item.items(): + elems_by_key[key].append(elem) + + return MultiModalKwargs({ + key: + elems[0].field.reduce_data(elems, pin_memory=pin_memory) + for key, elems in elems_by_key.items() + }) + + +MultiModalKwargsOptionalItems: TypeAlias = Union[ + MultiModalKwargsItems[MultiModalKwargsItem], + MultiModalKwargsItems[Optional[MultiModalKwargsItem]], +] + + +class MultiModalKwargs(UserDict[str, NestedTensors]): + """ + A dictionary that represents the keyword arguments to + [`torch.nn.Module.forward`][]. + """ + + @staticmethod + @deprecated("`MultiModalKwargs.from_hf_inputs` is deprecated and " + "will be removed in v0.13. 
" + "Please use `MultiModalKwargsItems.from_hf_inputs` and " + "access the tensor data using `.get_data()`.") + def from_hf_inputs( + hf_inputs: "BatchFeature", + config_by_key: Mapping[str, MultiModalFieldConfig], + ): + return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key) \ + .get_data() + + @staticmethod + @deprecated("`MultiModalKwargs.from_items` is deprecated and " + "will be removed in v0.13. " + "Please use `MultiModalKwargsItems.from_seq` and " + "access the tensor data using `.get_data()`.") def from_items( items: Sequence[MultiModalKwargsItem], *, pin_memory: bool = False, ): - """Construct a new - [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] - from multiple items.""" - elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) - for item in items: - for key, elem in item.items(): - elems_by_key[key].append(elem) - - data = { - key: elems[0].field.reduce_data(elems, pin_memory=pin_memory) - for key, elems in elems_by_key.items() if len(elems) > 0 - } - - return MultiModalKwargs(data, items=items) - - def __init__( - self, - data: Mapping[str, NestedTensors], - *, - items: Optional[Sequence[MultiModalKwargsItem]] = None, - ) -> None: - super().__init__(data) - - items_by_modality = full_groupby(items or [], key=lambda x: x.modality) - self._items_by_modality = dict(items_by_modality) - - @property - def modalities(self): - return self._items_by_modality.keys() + return MultiModalKwargsItems.from_seq(items) \ + .get_data(pin_memory=pin_memory) @staticmethod def _try_stack(nested_tensors: NestedTensors, @@ -867,54 +871,24 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): return cast(BatchedTensorInputs, json_mapped) - def __delitem__(self, key: str) -> None: - super().__delitem__(key) + def __getitem__(self, key: str): + if key not in self: + raise KeyError(f"Keyword argument {key!r} not found. 
" + f"Available keys: {set(self.keys())}") - for items in self._items_by_modality.values(): - for item in items: - item.pop(key, None) + return super().__getitem__(key) def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return False - if self._items_by_modality != other._items_by_modality: - return False - ks = self.keys() - return (ks == other.keys() - and all(nested_tensors_equal(self[k], other[k]) for k in ks)) + for k in self: + if k not in other: + return False + if not nested_tensors_equal(self[k], other[k]): + return False - def _validate_modality(self, method_name: str, modality: str) -> None: - if not self._items_by_modality: - raise RuntimeError( - f"`{method_name}` is not supported when " - "MultiModalKwargs is not initialized with `items`") - - if modality not in self._items_by_modality: - available_modalities = set(self._items_by_modality.keys()) - raise KeyError(f"Modality {modality!r} not found. " - f"Available modalities: {available_modalities}") - - def get_item_count(self, modality: str) -> int: - """Get the number of items belonging to a modality.""" - self._validate_modality("get_item_count", modality) - return len(self._items_by_modality[modality]) - - def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem: - """ - Get the keyword arguments corresponding to an item identified by - its modality and index. - """ - self._validate_modality("get_item", modality) - return self._items_by_modality[modality][item_index] - - def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]: - """ - Get the keyword arguments corresponding to each item belonging to - a modality. 
- """ - self._validate_modality("get_items", modality) - return self._items_by_modality[modality] + return True MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]] @@ -942,10 +916,10 @@ class MultiModalInputs(TypedDict): token_type_ids: NotRequired[list[int]] """The token type IDs of the prompt.""" - mm_kwargs: MultiModalKwargs + mm_kwargs: MultiModalKwargsOptionalItems """Keyword arguments to be directly passed to the model after batching.""" - mm_hashes: Optional["MultiModalHashDict"] + mm_hashes: "MultiModalHashDict" """The hashes of the multi-modal data.""" mm_placeholders: "MultiModalPlaceholderDict" diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 37f561274272b..88bb99529f200 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -16,7 +16,7 @@ from vllm.utils import LazyLoader, is_list_of from .audio import AudioResampler from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, ImageItem, ModalityData, MultiModalDataDict, - MultiModalFieldConfig, MultiModalKwargs, VideoItem) + MultiModalFieldConfig, MultiModalKwargsItems, VideoItem) _T = TypeVar("_T") _I = TypeVar("_I") @@ -157,19 +157,16 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor], self.fields_config = fields_config self.required_fields = required_fields - self._kwargs = MultiModalKwargs.from_hf_inputs( + self._kwargs = MultiModalKwargsItems.from_hf_inputs( BatchFeature(dict(data)), fields_config, ) def get_count(self) -> int: - return self._kwargs.get_item_count(self.modality) + return len(self._kwargs[self.modality]) def get(self, index: int) -> Mapping[str, torch.Tensor]: - return { - k: v.data - for k, v in self._kwargs.get_item(self.modality, index).items() - } + return self._kwargs[self.modality][index].get_data() def get_processor_data(self) -> Mapping[str, object]: return {} diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 38c5d5d99f63e..6ecdf80d4aa6f 100644 
--- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, Sequence) -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from enum import Enum from functools import lru_cache from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, @@ -20,11 +20,11 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, encode_tokens) from vllm.utils import flatten_2d_lists, full_groupby -from .cache import MultiModalCache from .hasher import MultiModalHasher from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs, - MultiModalKwargsItem, PlaceholderRange) + MultiModalFieldConfig, MultiModalInputs, + MultiModalKwargsItem, MultiModalKwargsItems, + MultiModalKwargsOptionalItems, PlaceholderRange) from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, MultiModalDataParser) @@ -33,6 +33,7 @@ if TYPE_CHECKING: from transformers.feature_extraction_utils import BatchFeature from transformers.processing_utils import ProcessorMixin + from .cache import BaseMultiModalProcessorCache from .profiling import BaseDummyInputsBuilder logger = init_logger(__name__) @@ -43,10 +44,59 @@ PromptSeq = Union[str, list[int]] """A token sequence (list of token IDs) or text.""" +@lru_cache(maxsize=2048) +def _cached_encode( + tokenizer: AnyTokenizer, + text: str, + *, + add_special_tokens: Optional[bool] = None, +) -> list[int]: + return encode_tokens(tokenizer, + text, + add_special_tokens=add_special_tokens) + + +@lru_cache(maxsize=2048) +def _cached_decode( + tokenizer: AnyTokenizer, + token_ids: tuple[int, ...], + *, + skip_special_tokens: Optional[bool] = None, +) -> str: + return decode_tokens(tokenizer, + list(token_ids), + 
skip_special_tokens=skip_special_tokens) + + +def _seq2text(tokenizer: AnyTokenizer, seq: PromptSeq) -> str: + if isinstance(seq, str): + return seq + + return _cached_decode(tokenizer, tuple(seq)) + + +def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]: + if isinstance(seq, str): + return _cached_encode(tokenizer, seq, add_special_tokens=False) + + return seq + + +class _GetMatchIndex(Protocol): + + def __call__( + self, + tokenizer: AnyTokenizer, + prompt: PromptSeq, + start_idx: int = 0, + ) -> Optional[int]: + ... + + @dataclass class PromptIndex: """Resolves to an index in the prompt.""" - get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]] + get_match_index: _GetMatchIndex class PromptIndexTargets: @@ -58,7 +108,7 @@ class PromptIndexTargets: This results in a match even if the prompt is empty. """ - return PromptIndex(lambda tok, prompt: 0) + return PromptIndex(lambda tokenizer, prompt, start_idx=0: 0) @staticmethod def prefix(seq: PromptSeq) -> PromptIndex: @@ -69,7 +119,11 @@ class PromptIndexTargets: def get_match_index( tokenizer: AnyTokenizer, prompt: PromptSeq, + start_idx: int = 0, ) -> Optional[int]: + if start_idx != 0: + return None + prefix = seq if isinstance(prompt, str): @@ -95,14 +149,24 @@ class PromptIndexTargets: This results in a match even if the prompt is empty. """ - return PromptIndex(lambda tok, prompt: len(prompt)) + return PromptIndex(lambda tokenizer, prompt, start_idx=0: len(prompt)) -PromptTarget = Union[PromptSeq, PromptIndex] +UpdateTarget = Union[PromptSeq, PromptIndex] """ The token sequence or text to update. """ +PromptUpdateTarget = Union[Callable[[int], UpdateTarget], UpdateTarget] +""" +Given the index of the processed item within +[`modality`][vllm.multimodal.processing.PromptUpdate.modality], +output the corresponding token sequence (or text). + +For convenience, you can directly pass in the token sequence (or text) +instead of a function if it does not depend on the input. 
+""" + @dataclass class PromptUpdateDetails(Generic[_S]): @@ -111,7 +175,8 @@ class PromptUpdateDetails(Generic[_S]): full: _S """The full content.""" - is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None + is_embed: Optional[Callable[[AnyTokenizer, PromptSeq], + torch.Tensor]] = None """ Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full], return a boolean mask of shape `(len(full),)` indicating which positions @@ -133,11 +198,12 @@ class PromptUpdateDetails(Generic[_S]): embed_text: str, ) -> "PromptUpdateDetails[_S]": - def is_embed(full: "_BoundPromptSequence") -> torch.Tensor: - embed_token_ids = encode_tokens(full.tokenizer, embed_text) + def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: + embed_token_ids = encode_tokens(tokenizer, embed_text) + token_ids = _seq2tokens(tokenizer, full) return torch.isin( - torch.tensor(full.token_ids), + torch.tensor(token_ids), torch.tensor(embed_token_ids), ) @@ -148,10 +214,13 @@ class PromptUpdateDetails(Generic[_S]): seq: _S, embed_token_id: int, ) -> "PromptUpdateDetails[_S]": - return PromptUpdateDetails( - full=seq, - is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id, - ) + + def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: + token_ids = _seq2tokens(tokenizer, full) + + return torch.tensor(token_ids) == embed_token_id + + return PromptUpdateDetails(full=seq, is_embed=is_embed) PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails] @@ -189,7 +258,7 @@ class PromptUpdate(ABC): modality: str """The modality for which the update is made.""" - target: PromptTarget + target: PromptUpdateTarget """The token sequence (or text) to update.""" @property @@ -204,10 +273,35 @@ class PromptUpdate(ABC): """Defines how to update the prompt.""" raise NotImplementedError - def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate": - return BoundPromptUpdate( - _origin=self, - tokenizer=tokenizer, + def _resolve_target(self, 
item_idx: int) -> UpdateTarget: + target = self.target + if callable(target): + target = target(item_idx) + + return target + + def _resolve_content(self, item_idx: int) -> PromptUpdateDetails: + content = self.content + if callable(content): + content = content(item_idx) + + if not isinstance(content, PromptUpdateDetails): + content = PromptUpdateDetails.from_seq(content) + + return content + + def resolve(self, item_idx: int) -> "ResolvedPromptUpdate": + """ + Given the index of the processed item within + [`modality`][vllm.multimodal.processing.PromptUpdate.modality], + output a copy of this object with its lazy attributes resolved. + """ + return ResolvedPromptUpdate( + modality=self.modality, + item_idx=item_idx, + mode=self.mode, + target=self._resolve_target(item_idx), + content=self._resolve_content(item_idx), ) @@ -354,30 +448,6 @@ class PromptReplacement(PromptUpdate): return UpdateMode.REPLACE -@lru_cache(maxsize=2048) -def _cached_encode( - tokenizer: AnyTokenizer, - text: str, - *, - add_special_tokens: Optional[bool] = None, -) -> list[int]: - return encode_tokens(tokenizer, - text, - add_special_tokens=add_special_tokens) - - -@lru_cache(maxsize=2048) -def _cached_decode( - tokenizer: AnyTokenizer, - token_ids: tuple[int, ...], - *, - skip_special_tokens: Optional[bool] = None, -) -> str: - return decode_tokens(tokenizer, - list(token_ids), - skip_special_tokens=skip_special_tokens) - - class _HasModalityAttr(Protocol): modality: str @@ -398,126 +468,103 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]: return full_groupby(values, key=lambda x: x.modality) -@dataclass -class _BoundPromptSequence: - """ - A [`_PromptSeq`][vllm.multimodal.processing.PromptSeq] bound - to a tokenizer to automatically - convert between token sequence and text representations. 
- """ - tokenizer: AnyTokenizer = field(repr=False) +class PromptTargetMatch(NamedTuple): + start_idx: int + end_idx: int - _text: Optional[str] - _token_ids: Optional[list[int]] - @staticmethod - def from_seq( +@dataclass(frozen=True) +class ResolvedPromptUpdate: + """ + A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] with its + lazy attributes resolved, apart from those related to tokenization. + """ + + modality: str + """The modality for which the update is made.""" + + item_idx: int + """The index within `modality` of the item this update pertains to.""" + + mode: UpdateMode + """Defines how to update the prompt.""" + + target: UpdateTarget + """The token sequence (or text) to update.""" + + content: PromptUpdateDetails = field(repr=False) + """The placeholder tokens that are part of the update.""" + + def iter_token_matches( + self, + prompt: list[int], tokenizer: AnyTokenizer, - seq: PromptSeq, - ) -> "_BoundPromptSequence": - return _BoundPromptSequence( - tokenizer=tokenizer, - _text=seq if isinstance(seq, str) else None, - _token_ids=seq if isinstance(seq, list) else None, - ) - - def __post_init__(self) -> None: - if self._text is None and self._token_ids is None: - raise ValueError("At least one of 'text' and 'token_ids' must be " - "specified") - - @property - def text(self) -> str: - if self._text is None: - assert self._token_ids is not None - self._text = _cached_decode(self.tokenizer, tuple(self._token_ids)) - - return self._text - - @property - def token_ids(self) -> list[int]: - if self._token_ids is None: - assert self._text is not None - self._token_ids = _cached_encode(self.tokenizer, - self._text, - add_special_tokens=False) - - return self._token_ids - - -@dataclass -class _BoundPromptContent: - full: _BoundPromptSequence - is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] - - -@dataclass -class BoundPromptUpdate: - """ - A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] bound - to a tokenizer to 
automatically convert - [`target`][vllm.multimodal.processing.PromptUpdate.target] and the result of - [`get_content`][vllm.multimodal.processing.BoundPromptUpdate.get_content] - between token sequence and text representations. - """ - _origin: PromptUpdate - tokenizer: AnyTokenizer = field(repr=False) - - def __post_init__(self) -> None: - self._content_cache = dict[int, _BoundPromptContent]() - - @property - def modality(self) -> str: - return self._origin.modality - - @property - def target(self) -> Union[_BoundPromptSequence, PromptIndex]: - """The token sequence (or text) to update.""" - target = self._origin.target + *, + start_idx: int = 0, + ) -> Generator[PromptTargetMatch]: + """Yield each instance of `self.target` found in `prompt`.""" + target = self.target if isinstance(target, PromptIndex): - return target + match_idx = target.get_match_index(tokenizer, prompt, start_idx) + if match_idx is not None: + yield PromptTargetMatch(match_idx, match_idx) - return _BoundPromptSequence.from_seq(self.tokenizer, target) + return - @property - def content(self) -> PromptUpdateContent: - """The placeholder tokens that are part of the update.""" - return self._origin.content + target_token_ids = _seq2tokens(tokenizer, target) - @property - def mode(self) -> UpdateMode: - """Defines how to update the prompt.""" - return self._origin.mode + for match in iter_token_matches(prompt, + target_token_ids, + start_idx=start_idx): + yield PromptTargetMatch(match.start_idx, match.end_idx) - def get_content(self, item_idx: int) -> _BoundPromptContent: - """ - Given the index of the processed item within - [`modality`][vllm.multimodal.processing.PromptUpdate.modality], - output the token sequence (or text) to update. 
- """ - content = self.content - if callable(content): - cache_key = item_idx - if cache_key in self._content_cache: - return self._content_cache[cache_key] + def iter_text_matches( + self, + prompt: str, + tokenizer: AnyTokenizer, + *, + start_idx: int = 0, + ) -> Generator[PromptTargetMatch]: + """Yield each instance of `self.target` found in `prompt`.""" + target = self.target - content = content(item_idx) - else: - cache_key = None + if isinstance(target, PromptIndex): + match_idx = target.get_match_index(tokenizer, prompt, start_idx) + if match_idx is not None: + yield PromptTargetMatch(match_idx, match_idx) + return + + target_text = _seq2text(tokenizer, target) + + for match in re.finditer(re.escape(target_text), prompt, + pos=start_idx): + yield PromptTargetMatch(match.start(), match.end()) + + def iter_matches( + self, + prompt: Union[list[int], str], + tokenizer: AnyTokenizer, + *, + start_idx: int = 0, + ) -> Generator[PromptTargetMatch]: + """Yield each instance of `self.target` found in `prompt`.""" + if isinstance(prompt, str): + return self.iter_text_matches(prompt, + tokenizer, + start_idx=start_idx) + + return self.iter_token_matches(prompt, tokenizer, start_idx=start_idx) + + def with_target(self, target: UpdateTarget): + return replace(self, target=target) + + def with_content(self, content: PromptUpdateInfo): if not isinstance(content, PromptUpdateDetails): content = PromptUpdateDetails.from_seq(content) - bound_full = _BoundPromptSequence.from_seq(self.tokenizer, - content.full) - bound_content = _BoundPromptContent(full=bound_full, - is_embed=content.is_embed) - - if cache_key is not None: - self._content_cache[cache_key] = bound_content - - return bound_content + return replace(self, content=content) class _TokenMatch(NamedTuple): @@ -528,6 +575,8 @@ class _TokenMatch(NamedTuple): def iter_token_matches( token_ids: list[int], match_ids: list[int], + *, + start_idx: int = 0, ) -> Generator[_TokenMatch]: """ Yield each occurrence of `match_ids` 
in `token_ids`. @@ -540,7 +589,6 @@ def iter_token_matches( if match_len == 0: return - start_idx = 0 while start_idx < prompt_len - match_len + 1: end_idx = start_idx + match_len @@ -580,68 +628,6 @@ def replace_token_matches( return flatten_2d_lists(out_seqs) -@dataclass(repr=False) -class PromptTargetMatch(ABC): - _origin: BoundPromptUpdate - - @property - def modality(self) -> str: - return self._origin.modality - - @property - @abstractmethod - def start_idx(self) -> int: - raise NotImplementedError - - @property - @abstractmethod - def end_idx(self) -> int: - raise NotImplementedError - - def __repr__(self) -> str: - return (f"{type(self).__name__}(modality={self.modality!r}, " - f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})") - - -@dataclass(repr=False) -class _PromptTargetIndexMatch(PromptTargetMatch): - match_idx: int - - @property - def start_idx(self) -> int: - return self.match_idx - - @property - def end_idx(self) -> int: - return self.match_idx - - -@dataclass(repr=False) -class _PromptTargetTokenMatch(PromptTargetMatch): - match: _TokenMatch - - @property - def start_idx(self) -> int: - return self.match.start_idx - - @property - def end_idx(self) -> int: - return self.match.end_idx - - -@dataclass(repr=False) -class _PromptTargetTextMatch(PromptTargetMatch): - match: re.Match[str] - - @property - def start_idx(self) -> int: - return self.match.start() - - @property - def end_idx(self) -> int: - return self.match.end() - - @dataclass class PlaceholderFeaturesInfo: modality: str @@ -664,163 +650,161 @@ class PlaceholderFeaturesInfo: ) -def find_token_matches( - prompt: list[int], - prompt_updates: Sequence[BoundPromptUpdate], -) -> Sequence[PromptTargetMatch]: - """Return each target of `prompt_updates` found in `prompt`.""" - - def get_matches(update: BoundPromptUpdate): - target = update.target - - if isinstance(target, PromptIndex): - match_idx = target.get_match_index(update.tokenizer, prompt) - if match_idx is None: - return [] - - 
return [_PromptTargetIndexMatch(update, match_idx)] - - return [ - _PromptTargetTokenMatch(update, match) - for match in iter_token_matches(prompt, target.token_ids) - ] - - return [ - match for update in prompt_updates for match in get_matches(update) - ] +_MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]] -def find_text_matches( - prompt: str, - prompt_updates: Sequence[BoundPromptUpdate], -) -> Sequence[PromptTargetMatch]: - """Return each target of `prompt_updates` found in `prompt`.""" +def _find_matches( + prompt: _S, + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, + *, + prev_end_idx: int = 0, + current_result: "MultiModalPromptUpdatesApplyResult", +) -> tuple[Optional[UpdateMode], list[_MatchToApply]]: + mode: Optional[UpdateMode] = None + mm_matches = dict[tuple[str, int], tuple[PromptTargetMatch, int]]() - def get_matches(update: BoundPromptUpdate): - target = update.target + for modality, modality_updates in mm_prompt_updates.items(): + for item_idx, item_updates in enumerate(modality_updates): + if current_result[modality][item_idx] is not None: + continue # Updates have already been applied for this item - if isinstance(target, PromptIndex): - match_idx = target.get_match_index(update.tokenizer, prompt) - if match_idx is None: - return [] + for update_idx, update in enumerate(item_updates): + if (modality, item_idx) in mm_matches: + break # Already found a match for this item - return [_PromptTargetIndexMatch(update, match_idx)] + for match in update.iter_matches( + prompt, + tokenizer, + start_idx=prev_end_idx, + ): + # All matches should share the same mode + if mode is None: + mode = update.mode + elif mode != update.mode: + continue - return [ - _PromptTargetTextMatch(update, match) - for match in re.finditer(re.escape(target.text), prompt) - ] + mm_matches[(modality, item_idx)] = match, update_idx + break # Get only the first valid match per item - return [ - match for update in prompt_updates for 
match in get_matches(update) - ] + # Prioritize earlier matches + matches_to_apply = sorted(mm_matches.items(), key=lambda item: item[1][0]) + # To avoid conflicts, only replace one non-empty item at a time + if mode == UpdateMode.REPLACE: + matches_to_apply_ = list[_MatchToApply]() + has_non_empty_matches = False -def _resolve_matches( - prompt: PromptSeq, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], -) -> list[PromptTargetMatch]: - """ - Resolve `mm_matches` to ensure that there are no overlapping matches, - and sort them such that earlier matches take priority over later ones. - """ - matches = [m for matches in mm_matches.values() for m in matches] + for item in matches_to_apply: + _, (match, _) = item + if match.start_idx == match.end_idx: + matches_to_apply_.append(item) + elif not has_non_empty_matches: + has_non_empty_matches = True + matches_to_apply_.append(item) - seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt) + matches_to_apply = matches_to_apply_ - for match in matches: - for idx in range(match.start_idx, match.end_idx): - if seen_matches[idx] is not None: - raise ValueError("Found overlapping matches " - f"({seen_matches[idx]} and {match}) " - f"at index={idx} of prompt={prompt}") - - seen_matches[idx] = match - - return sorted(matches, key=lambda x: x.start_idx) + return mode, matches_to_apply def _apply_matches( prompt: _S, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], -) -> list[_S]: - """Apply the updates in `mm_matches` to `prompt`.""" + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, +) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]: + prompt_len = len(prompt) + out_seqs = list[Union[str, list[int]]]() - prev_end_idx = 0 - next_idx_by_modality = defaultdict[str, int](lambda: 0) + out_result: MultiModalPromptUpdatesApplyResult = { + m: [None] * len(items) + for m, items in mm_prompt_updates.items() + } - for match in 
_resolve_matches(prompt, mm_matches): - modality = match.modality + start_idx = prev_end_idx = 0 + while start_idx < max(prompt_len, 1): # Allow inserts into empty prompt + found = False - item_start_idx = next_idx_by_modality[modality] - max_item_count = mm_item_counts.get(modality, 0) - if item_start_idx >= max_item_count: - continue + mode, matches_to_apply = _find_matches( + prompt, + mm_prompt_updates, + tokenizer, + prev_end_idx=prev_end_idx, + current_result=out_result, + ) - start_idx = match.start_idx - end_idx = match.end_idx - origin = match._origin - mode = origin.mode + if mode is not None: + for (modality, item_idx), (match, update_idx) in matches_to_apply: + found = True - if mode == UpdateMode.INSERT: - out_seqs.append(prompt[prev_end_idx:end_idx]) - num_inserts = max_item_count - elif mode == UpdateMode.REPLACE: - out_seqs.append(prompt[prev_end_idx:start_idx]) - num_inserts = max_item_count if start_idx == end_idx else 1 - else: - assert_never(mode) + matched_update = mm_prompt_updates[modality][item_idx][ + update_idx] + matched_content = matched_update.content.full - item_end_idx = min(item_start_idx + num_inserts, max_item_count) + if mode == UpdateMode.INSERT: + end_idx_to_insert = match.end_idx + elif mode == UpdateMode.REPLACE: + end_idx_to_insert = match.start_idx + else: + assert_never(mode) - for item_idx in range(item_start_idx, item_end_idx): - content = origin.get_content(item_idx) - insert_seq = (content.full.text if isinstance(prompt, str) else - content.full.token_ids) + out_seqs.append(prompt[prev_end_idx:end_idx_to_insert]) + out_seqs.append( + _seq2text(tokenizer, matched_content + ) if isinstance(prompt, str) else _seq2tokens( + tokenizer, matched_content)) + out_result[modality][item_idx] = update_idx - out_seqs.append(insert_seq) + # Exclude overlapping matches + start_idx = prev_end_idx = match.end_idx - prev_end_idx = end_idx - next_idx_by_modality[modality] += item_end_idx - item_start_idx + if not found: + start_idx += 1 
out_seqs.append(prompt[prev_end_idx:]) - return cast(list[_S], out_seqs) + return cast(list[_S], out_seqs), out_result def apply_token_matches( prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], -) -> list[int]: - """Apply the updates in `mm_matches` to `prompt`.""" - if not mm_matches: - return prompt + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, +) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]: + """ + Apply the updates in `mm_prompt_updates` to `prompt`. - token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts) + Matches are exclusive even when multiple modalities share + the same placeholder tokens. In that case, the modality that + appears earlier in `mm_prompt_updates` takes priority. + """ + token_id_seqs, result = _apply_matches(prompt, mm_prompt_updates, + tokenizer) - return flatten_2d_lists(token_id_seqs) + return flatten_2d_lists(token_id_seqs), result def apply_text_matches( prompt: str, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], -) -> str: - """Apply the updates in `mm_matches` to `prompt`.""" - if not mm_matches: - return prompt + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, +) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]: + """ + Apply the updates in `mm_prompt_updates` to `prompt`. - texts = _apply_matches(prompt, mm_matches, mm_item_counts) + Matches are exclusive even when multiple modalities share + the same placeholder tokens. In that case, the modality that + appears earlier in `mm_prompt_updates` takes priority. 
+ """ + texts, result = _apply_matches(prompt, mm_prompt_updates, tokenizer) - return "".join(texts) + return "".join(texts), result def _iter_placeholders( - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], prompt: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, ) -> Iterable[PlaceholderFeaturesInfo]: """ Yield each set of placeholder tokens found in `prompt`. @@ -832,6 +816,8 @@ def _iter_placeholders( Note that empty matches are ignored. """ prompt_len = len(prompt) + mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()} + item_idx_by_modality = defaultdict[str, int](lambda: 0) start_idx = 0 @@ -843,9 +829,9 @@ def _iter_placeholders( if item_idx >= mm_item_counts.get(modality, 0): continue - for update_info in modality_updates: - content = update_info.get_content(item_idx) - content_tokens_full = content.full.token_ids + for update in modality_updates[item_idx]: + content = update.content + content_tokens_full = _seq2tokens(tokenizer, content.full) content_len_full = len(content_tokens_full) end_idx_full = start_idx + content_len_full @@ -855,7 +841,8 @@ def _iter_placeholders( if prompt[start_idx:end_idx_full] == content_tokens_full: content_is_embed = content.is_embed if content_is_embed is not None: - content_is_embed = content_is_embed(content.full) + content_is_embed = content_is_embed( + tokenizer, content.full) yield PlaceholderFeaturesInfo( modality=modality, @@ -879,29 +866,14 @@ def _iter_placeholders( def find_mm_placeholders( - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], prompt: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: - it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts) + it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer) return dict(full_groupby_modality(it)) -class 
ProcessingCache(MultiModalCache): - - def __init__(self, capacity_gb: float) -> None: - super().__init__() - - self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem) - - self.get = self._cache.get - self.put = self._cache.put - self.reset = self._cache.clear - - -_CacheItemOrHash = Union[MultiModalKwargsItem, str] - - class BaseProcessingInfo: """Base class to provide the information necessary for data processing.""" @@ -985,9 +957,29 @@ _I = TypeVar("_I", bound=BaseProcessingInfo) MultiModalHashes = dict[str, list[str]] """ A collection of hashes with a similar structure as -[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]. +[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. """ +MultiModalPromptUpdates = Mapping[str, list[Sequence[ResolvedPromptUpdate]]] +""" +A collection of prompt updates with a similar structure as +[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. +""" + +MultiModalPromptUpdatesApplyResult = Mapping[str, list[Optional[int]]] +""" +For an item `MultiModalPromptUpdates[k][i]`, +`MultiModalPromptUpdatesApplyResult[k][i]` represents the index of the +`ResolvedPromptUpdate` instance that has been applied, or `None` if none of the +`ResolvedPromptUpdate` instances have been applied. +""" + + +class MultiModalProcessingInfo(NamedTuple): + kwargs: MultiModalKwargsOptionalItems + hashes: MultiModalHashes + prompt_updates: MultiModalPromptUpdates + class BaseMultiModalProcessor(ABC, Generic[_I]): """ @@ -996,11 +988,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): Not to be confused with `transformers.ProcessorMixin`. 
""" - def __init__(self, - info: _I, - dummy_inputs: "BaseDummyInputsBuilder[_I]", - *, - cache: Optional[ProcessingCache] = None) -> None: + def __init__( + self, + info: _I, + dummy_inputs: "BaseDummyInputsBuilder[_I]", + *, + cache: Optional["BaseMultiModalProcessorCache"] = None, + ) -> None: super().__init__() self.info = info @@ -1095,7 +1089,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: """ Given the original multi-modal items for this modality @@ -1113,14 +1107,60 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): """ raise NotImplementedError + def _bind_and_group_updates( + self, + prompt_updates: Sequence[PromptUpdate], + mm_item_counts: Mapping[str, int], + ) -> MultiModalPromptUpdates: + return { + modality: [[update.resolve(item_idx) for update in updates] + for item_idx in range(mm_item_counts.get(modality, 0))] + for modality, updates in full_groupby_modality(prompt_updates) + } + + def _get_mm_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> MultiModalPromptUpdates: + unbound_prompt_updates = self._get_prompt_updates( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + out_mm_kwargs=out_mm_kwargs, + ) + + mm_prompt_updates = self._bind_and_group_updates( + unbound_prompt_updates, + mm_items.get_all_counts(), + ) + + for modality, prompt_updates in mm_prompt_updates.items(): + for item_idx, item_prompt_updates in enumerate(prompt_updates): + if len(item_prompt_updates) > 1: + logger.warning_once( + "Detected %d prompt updates for `mm_items[%r][%s]`. " + "Multiple prompt updates per item is now " + "deprecated and may be removed in v0.13. 
" + "Instead, please specify dynamic update targets " + "in the same prompt update definition by passing " + "a function to `PromptUpdate.target`.", + len(prompt_updates), + modality, + item_idx, + ) + + return mm_prompt_updates + def _find_mm_placeholders( self, - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], new_token_ids: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: - return find_mm_placeholders(mm_prompt_updates, new_token_ids, - mm_item_counts) + tokenizer = self.info.get_tokenizer() + + return find_mm_placeholders(new_token_ids, mm_prompt_updates, + tokenizer) def _get_hf_mm_data( self, @@ -1311,32 +1351,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): return prompt_ids, mm_processed_data, False - def _get_cache_missing_items( - self, - cache: ProcessingCache, - mm_data_items: MultiModalDataItems, - mm_hashes: MultiModalHashes, - ) -> tuple[dict[str, list[_CacheItemOrHash]], MultiModalDataItems]: - mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]] = { - modality: [(h if (v := cache.get(h)) is None else v) - for h in hashes] - for modality, hashes in mm_hashes.items() - } - - mm_missing_idxs = { - modality: [ - idx for idx, item_or_hash in enumerate(items_or_hashes) - if isinstance(item_or_hash, str) - ] - for modality, items_or_hashes in mm_cache_items_or_hashes.items() - } - mm_missing_data = { - modality: [mm_data_items[modality][idx] for idx in idxs] - for modality, idxs in mm_missing_idxs.items() - } - - return mm_cache_items_or_hashes, self._to_mm_items(mm_missing_data) - def _hash_mm_items( self, mm_items: MultiModalDataItems, @@ -1357,30 +1371,92 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): for modality, items in mm_items.items() } + def _get_cache_missing_items( + self, + cache: "BaseMultiModalProcessorCache", + mm_data_items: MultiModalDataItems, + mm_hashes: MultiModalHashes, + ) -> MultiModalDataItems: + 
mm_is_cached = { + modality: cache.is_cached(hashes) + for modality, hashes in mm_hashes.items() + } + + mm_missing_idxs = { + modality: [ + idx for idx, item_is_cached in enumerate(items_is_cached) + if not item_is_cached + ] + for modality, items_is_cached in mm_is_cached.items() + } + mm_missing_data = { + modality: [mm_data_items[modality][idx] for idx in idxs] + for modality, idxs in mm_missing_idxs.items() + } + + return self._to_mm_items(mm_missing_data) + + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + """ + Override this if other attributes of `ResolvedPromptUpdate` + also need to be recomputed after retrieving from the cache. + """ + return replace(cached_update, item_idx=new_item_idx) + def _merge_mm_kwargs( self, - cache: ProcessingCache, - mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]], - mm_missing_kwargs: MultiModalKwargs, - ) -> dict[str, list[MultiModalKwargsItem]]: + cache: "BaseMultiModalProcessorCache", + mm_hashes: MultiModalHashes, + mm_missing_kwargs: MultiModalKwargsItems, + mm_missing_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[MultiModalKwargsOptionalItems, MultiModalPromptUpdates]: + # Need to calculate this at the beginning to avoid skipping cache logic + # for subsequently repeated items in the same modality + mm_is_cached = { + modality: cache.is_cached(hashes) + for modality, hashes in mm_hashes.items() + } + mm_missing_next_idx = defaultdict[str, int](lambda: 0) - merged_items = defaultdict[str, list[MultiModalKwargsItem]](list) - for modality, items_or_hashes in mm_cache_items_or_hashes.items(): - for item_or_hash in items_or_hashes: - if isinstance(item_or_hash, str): - kw_item = mm_missing_kwargs.get_item( - modality, - mm_missing_next_idx[modality], - ) - cache.put(item_or_hash, kw_item) + merged_kwargs = defaultdict[str, + list[Optional[MultiModalKwargsItem]]](list) + merged_prompt_updates = defaultdict[ + str, 
list[Sequence[ResolvedPromptUpdate]]](list) + for modality, hashes in mm_hashes.items(): + missing_kwargs = mm_missing_kwargs.get(modality, []) + missing_prompt_updates = mm_missing_prompt_updates.get( + modality, []) + + for item_idx, item_hash in enumerate(hashes): + kwargs: Optional[MultiModalKwargsItem] + if not mm_is_cached[modality][item_idx]: + missing_next_idx = mm_missing_next_idx[modality] + kwargs = missing_kwargs[missing_next_idx] + updates = missing_prompt_updates[missing_next_idx] + mm_missing_next_idx[modality] += 1 + + item = kwargs, updates else: - kw_item = item_or_hash + item = None - merged_items[modality].append(kw_item) + kwargs, updates = cache.get_and_update_item(item, item_hash) - return dict(merged_items) + merged_kwargs[modality].append(kwargs) + merged_prompt_updates[modality].append([ + self._recompute_cached_prompt_update(update, item_idx) + for update in updates + ]) + + mm_kwargs = MultiModalKwargsItems(merged_kwargs) + mm_prompt_updates = dict(merged_prompt_updates) + + return mm_kwargs, mm_prompt_updates def _apply_hf_processor( self, @@ -1388,9 +1464,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: ( prompt_ids, mm_processed_data, @@ -1403,17 +1477,28 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): enable_hf_prompt_update=True, ) - mm_kwargs = MultiModalKwargs.from_hf_inputs( + mm_kwargs = MultiModalKwargsItems.from_hf_inputs( mm_processed_data, self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs), ) - mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, - tokenization_kwargs) - if return_mm_hashes else None) + mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, + 
tokenization_kwargs) - return prompt_ids, mm_kwargs, mm_hashes, is_update_applied + mm_prompt_updates = self._get_mm_prompt_updates( + mm_data_items, + hf_processor_mm_kwargs, + mm_kwargs, + ) + + mm_info = MultiModalProcessingInfo( + kwargs=mm_kwargs, + hashes=mm_hashes, + prompt_updates=mm_prompt_updates, + ) + + return prompt_ids, mm_info, is_update_applied def _cached_apply_hf_processor( self, @@ -1421,9 +1506,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: """ Apply the HF processor on the full prompt text, caching the results and reusing cached results. @@ -1437,22 +1520,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, ) mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, tokenization_kwargs) - ( - mm_cache_items_or_hashes, - mm_missing_data_items, - ) = self._get_cache_missing_items( + + mm_missing_data_items = self._get_cache_missing_items( cache=cache, mm_data_items=mm_data_items, mm_hashes=mm_hashes, ) - mm_hashes_to_return = mm_hashes if return_mm_hashes else None - # NOTE: `prompt` does not correspond to `mm_missing_data_items`, # so we can't apply prompt updates until the new multimodal # items are combined with the cached multimodal items @@ -1468,66 +1546,60 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): enable_hf_prompt_update=False, ) - mm_missing_kwargs = MultiModalKwargs.from_hf_inputs( + mm_missing_kwargs = MultiModalKwargsItems.from_hf_inputs( mm_missing_processed_data, self._get_mm_fields_config(mm_missing_processed_data, hf_processor_mm_kwargs), ) - mm_cache_items_merged 
= self._merge_mm_kwargs( - cache, - mm_cache_items_or_hashes=mm_cache_items_or_hashes, - mm_missing_kwargs=mm_missing_kwargs, + mm_missing_prompt_updates = self._get_mm_prompt_updates( + mm_missing_data_items, + hf_processor_mm_kwargs, + mm_missing_kwargs, ) - mm_kwargs = MultiModalKwargs.from_items([ - item for cache_items in mm_cache_items_merged.values() - for item in cache_items - ]) + mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs( + cache, + mm_hashes=mm_hashes, + mm_missing_kwargs=mm_missing_kwargs, + mm_missing_prompt_updates=mm_missing_prompt_updates, + ) - return prompt_ids, mm_kwargs, mm_hashes_to_return, is_update_applied + mm_info = MultiModalProcessingInfo( + kwargs=mm_kwargs, + hashes=mm_hashes, + prompt_updates=mm_prompt_updates, + ) - def _bind_and_group_updates( - self, - prompt_updates: Sequence[PromptUpdate], - ) -> dict[str, Sequence[BoundPromptUpdate]]: - tokenizer = self.info.get_tokenizer() - - it = (update.bind(tokenizer) for update in prompt_updates) - return dict(full_groupby_modality(it)) + return prompt_ids, mm_info, is_update_applied def _apply_token_matches( self, prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> list[int]: - return apply_token_matches(prompt, mm_matches, mm_item_counts) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: + tokenizer = self.info.get_tokenizer() + return apply_token_matches(prompt, mm_prompt_updates, tokenizer) def _apply_text_matches( self, prompt: str, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> str: - return apply_text_matches(prompt, mm_matches, mm_item_counts) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[str, MultiModalPromptUpdatesApplyResult]: + tokenizer = self.info.get_tokenizer() + return apply_text_matches(prompt, mm_prompt_updates, tokenizer) def _apply_prompt_updates( self, token_ids: 
list[int], - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: tokenizer = self.info.get_tokenizer() - mm_token_matches = { - modality: find_token_matches(token_ids, updates) - for modality, updates in mm_prompt_updates.items() - } - mm_match_counts = { - modality: len(matches) - for modality, matches in mm_token_matches.items() - } + new_token_ids, match_result = self._apply_token_matches( + token_ids, + mm_prompt_updates, + ) # If the search text does not represent a special token, # it may have different token IDs in the prompt, because @@ -1540,59 +1612,46 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): # of the search text in the prompt, we instead perform string-based # updates on the decoded token IDs, then encode them back. if all( - mm_match_counts.get(modality, 0) >= item_count - for modality, item_count in mm_item_counts.items() - ): # yapf: disable - token_ids = self._apply_token_matches( - token_ids, - mm_token_matches, - mm_item_counts, - ) - - text = decode_tokens(tokenizer, token_ids) - matched_updates = { - modality: [match._origin for match in token_matches] - for modality, token_matches in mm_token_matches.items() - } + all(update_idx is not None for update_idx in update_idxs) + for update_idxs in match_result.values()): + new_text = decode_tokens(tokenizer, new_token_ids) else: - text = decode_tokens(tokenizer, token_ids) - - mm_text_matches = { - modality: find_text_matches(text, updates) - for modality, updates in mm_prompt_updates.items() - } - text = self._apply_text_matches( - text, - mm_text_matches, - mm_item_counts, + new_text, match_result = self._apply_text_matches( + decode_tokens(tokenizer, token_ids), + mm_prompt_updates, ) - token_ids = encode_tokens(tokenizer, - text, - add_special_tokens=False) - matched_updates = { - modality: [match._origin for match in 
token_matches] - for modality, token_matches in mm_text_matches.items() - } + new_token_ids = encode_tokens( + tokenizer, + new_text, + add_special_tokens=False, + ) + + matched_updates = defaultdict[ + str, list[Sequence[ResolvedPromptUpdate]]](list) + for modality, update_idxs in match_result.items(): + for item_idx, update_idx in enumerate(update_idxs): + assert update_idx is not None, ( + "Failed to apply prompt replacement for " + f"mm_items[{modality!r}][{item_idx}]") + + matched_updates[modality].append( + [mm_prompt_updates[modality][item_idx][update_idx]]) placeholders = self._find_mm_placeholders( - matched_updates, - token_ids, - mm_item_counts, + new_token_ids, + dict(matched_updates), ) - return token_ids, text, placeholders + return new_token_ids, new_text, placeholders def _validate_mm_kwargs( self, - mm_kwargs: MultiModalKwargs, + mm_kwargs: MultiModalKwargsOptionalItems, mm_item_counts: Mapping[str, int], ) -> None: for modality, item_count in mm_item_counts.items(): - if modality in mm_kwargs.modalities: - items = mm_kwargs.get_items(modality) - else: - items = [] + items = mm_kwargs.get(modality, []) if len(items) != item_count: raise RuntimeError( @@ -1628,27 +1687,18 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): def _maybe_apply_prompt_updates( self, mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], prompt_ids: list[int], - mm_kwargs: MultiModalKwargs, + mm_kwargs: MultiModalKwargsOptionalItems, + mm_prompt_updates: MultiModalPromptUpdates, is_update_applied: bool, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: - unbound_prompt_updates = self._get_prompt_updates( - mm_items, - hf_processor_mm_kwargs, - mm_kwargs, - ) - mm_prompt_updates = self._bind_and_group_updates( - unbound_prompt_updates) - mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) if is_update_applied: mm_placeholders = self._find_mm_placeholders( - mm_prompt_updates, 
prompt_ids, - mm_item_counts, + mm_prompt_updates, ) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) @@ -1662,7 +1712,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ) = self._apply_prompt_updates( prompt_ids, mm_prompt_updates, - mm_item_counts, ) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) @@ -1674,7 +1723,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, ) -> MultiModalInputs: """ Process multi-modal inputs to be used in vLLM. @@ -1696,23 +1744,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ( prompt_ids, - mm_kwargs, - mm_hashes, + mm_info, is_update_applied, ) = self._cached_apply_hf_processor( prompt, mm_items, hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, ) # NOTE: tokenization_kwargs are not required to init processor prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates( mm_items=mm_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, prompt_ids=prompt_ids, - mm_kwargs=mm_kwargs, + mm_kwargs=mm_info.kwargs, + mm_prompt_updates=mm_info.prompt_updates, is_update_applied=is_update_applied, ) @@ -1725,8 +1771,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): type="multimodal", prompt=prompt, prompt_token_ids=prompt_ids, - mm_kwargs=mm_kwargs, - mm_hashes=mm_hashes, + mm_kwargs=mm_info.kwargs, + mm_hashes=mm_info.hashes, mm_placeholders=mm_placeholder_ranges, ) @@ -1789,7 +1835,6 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, ) -> MultiModalEncDecInputs: """ Process multi-modal inputs to be used in vLLM. 
@@ -1804,7 +1849,6 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): mm_data, hf_processor_mm_kwargs, tokenization_kwargs, - return_mm_hashes, ) return self._get_enc_dec_inputs( diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index d876887fc155d..ffc69a2db60a4 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -13,7 +13,7 @@ import vllm.envs as envs from vllm.logger import init_logger from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalInputs, MultiModalKwargs, + MultiModalInputs, MultiModalKwargsOptionalItems, MultiModalPlaceholderDict) from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, EncDecMultiModalProcessor) @@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple): """Dummy data used for profiling.""" prompt_token_ids: list[int] - multi_modal_data: MultiModalKwargs + multi_modal_data: MultiModalKwargsOptionalItems multi_modal_placeholders: MultiModalPlaceholderDict @@ -209,7 +209,7 @@ class MultiModalProfiler(Generic[_I]): if processor.pad_dummy_encoder_prompt: num_tokens_to_pad = max(total_len, seq_len) - total_len encoder_prompt_token_ids.extend([0] * num_tokens_to_pad) - # NOTE: Whisper allows total_len > seq_len. + # NOTE: Whisper and Donut allows total_len > seq_len. 
elif total_len > seq_len and not envs.VLLM_USE_V1: # `max_num_batched_tokens` is defined by `SchedulerConfig` logger.warning_once( diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index ded56cca80999..38adbf8f3536a 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Mapping from dataclasses import dataclass -from functools import lru_cache from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar import torch.nn as nn @@ -13,8 +12,9 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, cached_tokenizer_from_config) from vllm.utils import ClassRegistry -from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, - ProcessingCache) +from .cache import (BaseMultiModalProcessorCache, + processor_only_cache_from_config) +from .processing import BaseMultiModalProcessor, BaseProcessingInfo from .profiling import (BaseDummyInputsBuilder, DummyDecoderData, DummyEncoderData, MultiModalProfiler) @@ -65,7 +65,7 @@ class MultiModalProcessorFactory(Protocol[_I]): info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor[_I]: ... 
@@ -80,20 +80,13 @@ class _ProcessorFactories(Generic[_I]): self, ctx: InputProcessingContext, *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ): info = self.info(ctx) dummy_inputs_builder = self.dummy_inputs(info) return self.processor(info, dummy_inputs_builder, cache=cache) -# Make sure a different cache is used for each model config -# NOTE: ModelConfig is not hashable so it cannot be passed directly -@lru_cache(maxsize=1) -def _get_processor_cache(model_id: str, capacity_gb: int): - return ProcessingCache(capacity_gb) if capacity_gb > 0 else None - - class MultiModalRegistry: """ A registry that dispatches data processing according to the model. @@ -103,31 +96,6 @@ class MultiModalRegistry: self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]() - def _get_processor_cache(self, model_config: "ModelConfig"): - model_id = model_config.model - capacity_gb = model_config.mm_processor_cache_gb - return _get_processor_cache(model_id, capacity_gb) - - def reset_processor_cache(self, model_config: "ModelConfig") -> bool: - """Reset the multi-modal processing cache.""" - if processor_cache := self._get_processor_cache(model_config): - processor_cache.reset() - - return True # Success - - def enable_mm_input_cache(self, model_config: "ModelConfig") -> bool: - """Whether the multi-modal input cache should be enabled. - NOTE: This is put under MultiModalRegistry on purpose to respect - text-only mode for multimodal models. - """ - - if not self.supports_multimodal_inputs(model_config): - return False - - mm_config = model_config.get_multimodal_config() - - return mm_config.mm_processor_cache_gb > 0 - def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool: """ Checks if the model supports multimodal inputs. 
@@ -157,6 +125,8 @@ class MultiModalRegistry: def get_max_tokens_per_item_by_modality( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of tokens per data item from each modality based @@ -165,11 +135,11 @@ class MultiModalRegistry: if not model_config.is_multimodal_model: return {} - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) seq_len = model_config.max_model_len - mm_limits = self.get_mm_limits_per_prompt(model_config) + mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) return profiler.get_mm_max_contiguous_tokens( seq_len, @@ -182,6 +152,8 @@ class MultiModalRegistry: def get_max_tokens_per_item_by_nonzero_modality( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of tokens per data item from each modality based @@ -192,15 +164,19 @@ class MultiModalRegistry: This is currently directly used only in V1 for profiling the memory usage of a model. """ - mm_limits = self.get_mm_limits_per_prompt(model_config) + mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) + max_tokens_per_item = self.get_max_tokens_per_item_by_modality( + model_config, + cache=cache, + ) return { key: max_tokens_per_mm_item - for key, max_tokens_per_mm_item in - self.get_max_tokens_per_item_by_modality(model_config).items() + for key, max_tokens_per_mm_item in max_tokens_per_item.items() if mm_limits[key] > 0 } + # TODO: Remove once V0 is gone def get_max_tokens_by_modality( self, model_config: "ModelConfig", @@ -209,14 +185,19 @@ class MultiModalRegistry: Get the maximum number of tokens from each modality for profiling the memory usage of a model. 
""" - mm_limits = self.get_mm_limits_per_prompt(model_config) + cache = processor_only_cache_from_config(model_config, self) + mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) + max_tokens_per_item = self.get_max_tokens_per_item_by_modality( + model_config, + cache=cache, + ) return { key: mm_limits[key] * max_tokens_per_mm_item - for key, max_tokens_per_mm_item in - self.get_max_tokens_per_item_by_modality(model_config).items() + for key, max_tokens_per_mm_item in max_tokens_per_item.items() } + # TODO: Remove once V0 is gone def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ Get the maximum number of multi-modal tokens @@ -227,6 +208,8 @@ class MultiModalRegistry: def get_mm_limits_per_prompt( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of multi-modal input instances for each modality @@ -235,7 +218,7 @@ class MultiModalRegistry: if not model_config.is_multimodal_model: return {} - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) return profiler.get_mm_limits() @@ -303,7 +286,7 @@ class MultiModalRegistry: model_config: "ModelConfig", *, tokenizer: Optional[AnyTokenizer] = None, - disable_cache: Optional[bool] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor[BaseProcessingInfo]: """ Create a multi-modal processor for a specific model and tokenizer. 
@@ -311,15 +294,10 @@ class MultiModalRegistry: if not model_config.is_multimodal_model: raise ValueError(f"{model_config.model} is not a multimodal model") - if disable_cache is None: - disable_cache = not model_config.enable_mm_processor_cache - model_cls = self._get_model_cls(model_config) factories = self._processor_factories[model_cls] ctx = self._create_processing_ctx(model_config, tokenizer) - cache = None if disable_cache else self._get_processor_cache( - model_config) return factories.build_processor(ctx, cache=cache) @@ -328,13 +306,15 @@ class MultiModalRegistry: model_config: "ModelConfig", seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> DummyDecoderData: """ Create dummy data for profiling the memory usage of a model. The model is identified by ``model_config``. """ - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts) @@ -352,13 +332,15 @@ class MultiModalRegistry: model_config: "ModelConfig", seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> DummyEncoderData: """ Create dummy data for profiling the memory usage of a model. The model is identified by ``model_config``. """ - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts) @@ -372,3 +354,22 @@ class MultiModalRegistry: ) return dummy_data + + def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int: + """ + Get the maximum length of the encoder input for encoder-decoder models. 
+ """ + if not model_config.is_encoder_decoder: + return 0 + max_tokens = self.\ + get_max_tokens_per_item_by_nonzero_modality(model_config) + if not max_tokens: + # TODO - this function assumes encoder-decoder models are + # multimodal. This will need to change when adding support for more + # than whisper. + return 0 + assert len(max_tokens) == 1, "Encoder-decoder models are expected \ + to implement the multimodal interface with at most one modality." + + first_modality = next(iter(max_tokens)) + return max_tokens[first_modality] diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index f914d0dc6c5e7..834b2189e4bed 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -3,6 +3,8 @@ import asyncio import atexit +import itertools +import math from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor from itertools import groupby @@ -32,11 +34,13 @@ _M = TypeVar("_M") if TYPE_CHECKING: from .inputs import (BatchedTensorInputs, MultiModalKwargs, - MultiModalKwargsItem, MultiModalPlaceholderDict) + MultiModalKwargsItem, MultiModalKwargsItems, + MultiModalPlaceholderDict) else: BatchedTensorInputs = Any MultiModalKwargs = Any MultiModalKwargsItem = Any + MultiModalKwargsItems = Any MultiModalPlaceholderDict = Any global_thread_pool = ThreadPoolExecutor( @@ -359,18 +363,20 @@ def argsort_mm_positions( "`group_mm_kwargs_by_modality` and will be removed in v0.13. " "Please use `group_mm_kwargs_by_modality` instead.") def group_mm_inputs_by_modality( - mm_inputs: list[MultiModalKwargs]) -> list[list[MultiModalKwargs]]: + mm_inputs: list[MultiModalKwargsItems] +) -> list[list[MultiModalKwargsItems]]: if not mm_inputs: return [] - def modality_group_func(mm_input: MultiModalKwargs) -> Union[str, int]: + def modality_group_func( + mm_input: MultiModalKwargsItems) -> Union[str, int]: # If the input has multiple modalities, return a id as the unique key # for the mm_input input. 
- if len(mm_input.modalities) > 1: + if len(mm_input) > 1: return id(mm_input) - elif len(mm_input.modalities) == 1: - return list(mm_input.modalities)[0] + elif len(mm_input) == 1: + return next(iter(mm_input.keys())) # FIXME(Isotr0py): Modality of mm_input from legacy pipeline is empty, # this is used to make InternVL with legacy pipeline still work with v1. @@ -397,17 +403,19 @@ def group_mm_kwargs_by_modality( Yields: A tuple `(modality, num_items, grouped_kwargs)`. """ - from vllm.multimodal.inputs import MultiModalKwargs + from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems for modality, items in groupby(mm_kwargs, key=lambda item: item.modality): items_lst = list(items) - # mm_kwargs_group = MultiModalKwargs.from_items(items_lst, - # pin_memory=pin_memory) + # mm_kwargs_group = MultiModalKwargsItems.from_items(items_lst) \ + # .get_data(pin_memory=pin_memory) # if device is not None: - # mm_kwargs_group = json_map_leaves(lambda x: x.to(device=device), - # mm_kwargs_group.data) + # mm_kwargs_group = json_map_leaves( + # lambda x: x.to(device=device), + # mm_kwargs_group, + # ) # TODO: Once V0 is removed, we can use the merging logic above # to avoid creating an extra batch dimension (except for fields @@ -415,7 +423,10 @@ def group_mm_kwargs_by_modality( # We will also need to update each model to remove `flatten_bn`. mm_kwargs_group = MultiModalKwargs.as_kwargs( MultiModalKwargs.batch( - [MultiModalKwargs.from_items([item]) for item in items_lst], + [ + MultiModalKwargsItems.from_seq([item]).get_data() + for item in items_lst + ], pin_memory=pin_memory, ), device=device, @@ -450,12 +461,227 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor, num_chunks_per_rank, ...] 
vision_embeddings = vision_model(image_input_per_rank) + # Ensure tensor is contiguous before all_gather + vision_embeddings = vision_embeddings.contiguous() vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings, dim=0) vision_embeddings = vision_embeddings[:num_chunks, ...] return vision_embeddings +def get_load_balance_assignment( + sizes: list[int], + num_gpus: int = 2, +) -> tuple[list[int], list[int], list[int]]: + """ + Generate load balancing assignment and metadata + for distributing data across GPUs. + The load is determined by the total image sizes, + not the number of images. + + Args: + sizes: The size of each image + num_gpus: Number of GPUs to balance across + + Returns: + shuffle_indices: + Indices to reorder data for balanced loading + gpu_sample_counts: + Number of samples assigned to each GPU + grouped_sizes_per_gpu: + Total size assigned to each GPU + + Example: + ``` + sizes = [1000, 100, 200, 50] + num_gpus=2 + ``` + + """ + + n_samples = len(sizes) + + # Handle edge cases + if n_samples == 0: + return [], [0] * num_gpus, [0] * num_gpus + + # Use greedy algorithm - balance by total size, not sample count + gpu_assignments = [list[int]() for _ in range(num_gpus)] + gpu_loads = [0] * num_gpus # This tracks total SIZE, not sample count + + # Sort indices by size (largest first for better load balancing) + # sizes = [1000, 100, 200, 50] + # large_to_small_indices = [0, 2, 1, 3] + large_to_small_indices = sorted(range(n_samples), + key=lambda i: sizes[i], + reverse=True) + + for idx in large_to_small_indices: + # Find GPU with minimum current load (by total size) + min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i]) + gpu_assignments[min_gpu].append(idx) + gpu_loads[min_gpu] += sizes[idx] + + # Create shuffle indices and counts + shuffle_indices = list[int]() + gpu_sample_counts = list[int]() + for gpu_id in range(num_gpus): + # GPU_0 = [1000] = [0] + # GPU_1 = [200, 100, 50] = [2, 1, 3] + # shuffle_indices = [0, 2, 1, 3] + 
shuffle_indices.extend(gpu_assignments[gpu_id]) + # GPU_0 = [1] + # GPU_1 = [3] + # gpu_sample_counts = [1, 3] + gpu_sample_counts.append(len(gpu_assignments[gpu_id])) + + return (shuffle_indices, gpu_sample_counts, gpu_loads) + + +def run_dp_sharded_mrope_vision_model( + vision_model: torch.nn.Module, + pixel_values: torch.Tensor, + grid_thw_list: list[list[int]], +) -> tuple[torch.Tensor, ...]: + """Run a vision model with data parallelism (DP) sharding. + The function will shard the input image tensor on the + first dimension and run the vision model. + This function is used to run the vision model with mrope. + + Args: + vision_model (torch.nn.Module): Vision model. + pixel_values (torch.Tensor): Image/Video input tensor. + grid_thw_list: List of grid dimensions for each image + Returns: + torch.Tensor: Output image embeddings + + Example: + ``` + vision_model.out_hidden_size = 64 + vision_model.spatial_merge_size = 2 + pixel_values.shape = (1350, channel) + grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]] + tp_size=2 + ``` + + """ + tp_size = get_tensor_model_parallel_world_size() + + # GPU_0 tp_rank_local = 0 + # GPU_1 tp_rank_local = 1 + tp_rank_local = get_tensor_model_parallel_rank() + + # patches_per_image = [1000, 100, 200, 50] + patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list] + # patches_per_image = [0, 1000, 1100, 1300, 1350] + cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)] + + # Get load balancing assignment with all metadata + # image_to_tp_rank = [0, 2, 1, 3] + # gpu_sample_counts = [1, 3] + # grouped_pixel_values_len = [1000, 350] + (image_to_tp_rank, gpu_sample_counts, + grouped_pixel_values_len) = get_load_balance_assignment( + patches_per_image, tp_size) + + # cu_gpu_sample_counts = [0, 1, 4] + cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)] + + # GPU_0 image_idxs_local = [0] + # GPU_1 image_idxs_local = [2, 1, 3] + image_idxs_local = 
image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]: + cum_gpu_sample_counts[tp_rank_local + + 1]] + + # Get the pixel values for the local images based on the image_idxs_local + if len(image_idxs_local) > 0: + pixel_values_local = torch.cat([ + pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]] + for i in image_idxs_local + ]) + else: + # Handle case where this rank has no images + pixel_values_local = torch.empty((0, pixel_values.shape[1]), + device=pixel_values.device, + dtype=pixel_values.dtype) + # embed_dim_reduction_factor = 2 * 2 + embed_dim_reduction_factor = (vision_model.spatial_merge_size * + vision_model.spatial_merge_size) + + # Find the max length across all ranks + # The output embedding of every DP rank has to be + # padded to this length for tensor_model_parallel_all_gather + # to work + max_len_per_rank = max( + grouped_pixel_values_len) // embed_dim_reduction_factor + local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local] + + # Run the vision model on the local pixel_values_local + if pixel_values_local.shape[0] > 0: + image_embeds_local = vision_model(pixel_values_local, + local_grid_thw_list) + else: + # Handle empty case + image_embeds_local = torch.empty((0, vision_model.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype) + + # Pad the output based on max_len_per_rank + # for tensor_model_parallel_all_gather to work + current_len = image_embeds_local.shape[0] + if current_len < max_len_per_rank: + padding_size = max_len_per_rank - current_len + padding = torch.empty((padding_size, image_embeds_local.shape[1]), + dtype=image_embeds_local.dtype, + device=image_embeds_local.device) + image_embeds_local_padded = torch.cat([image_embeds_local, padding], + dim=0) + else: + image_embeds_local_padded = image_embeds_local + + # Do all_gather to collect embeddings from all ranks + gathered_embeds = tensor_model_parallel_all_gather( + image_embeds_local_padded, dim=0) + + # Remove padding and 
reconstruct per-rank embeddings + rank_embeddings = list[torch.Tensor]() + for rank in range(tp_size): + start_idx = rank * max_len_per_rank + end_idx = start_idx + (grouped_pixel_values_len[rank] // + embed_dim_reduction_factor) + rank_embeddings.append(gathered_embeds[start_idx:end_idx]) + + patches_per_output_image = [(patch_size // embed_dim_reduction_factor) + for patch_size in patches_per_image] + + # Reconstruct embeddings in the original order + original_order_embeddings = [None] * len(grid_thw_list) + current_idx = 0 + for rank in range(tp_size): + count = gpu_sample_counts[rank] + if count > 0: + # Get images assigned to this rank in shuffled order + # GPU_0 = image_idxs_local [0] + # GPU_1 = image_idxs_local [2, 1, 3] + rank_images = image_to_tp_rank[current_idx:current_idx + count] + + rank_embed = rank_embeddings[rank] + # Split rank embeddings back to individual images + embed_start = 0 + for img_idx in rank_images: + img_patches = patches_per_output_image[img_idx] + original_order_embeddings[img_idx] = rank_embed[ + embed_start:embed_start + img_patches] + embed_start += img_patches + current_idx += count + + out_embeddings = tuple(embed for embed in original_order_embeddings + if embed is not None) + assert len(out_embeddings) == len( + original_order_embeddings), "Found unassigned embeddings" + return out_embeddings + + def fetch_audio( audio_url: str, audio_io_kwargs: Optional[dict[str, Any]] = None, diff --git a/vllm/outputs.py b/vllm/outputs.py index 9784a8894472f..acdb2f89ce735 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -409,7 +409,7 @@ class EmbeddingOutput: Args: embedding: The embedding vector, which is a list of floats. - Its length depends on the hidden dimension of the model. + Its length depends on the hidden dimension of the model. """ embedding: list[float] @@ -447,7 +447,7 @@ class ClassificationOutput: Args: probs: The probability vector, which is a list of floats. - Its length depends on the number of classes. 
+ Its length depends on the number of classes. """ probs: list[float] diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 0b16a8e1d1d8b..5686fae5cd7d1 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -268,7 +268,7 @@ class CpuPlatform(Platform): DEFAULT_MAX_NUM_BATCHED_TOKENS) @classmethod - def get_allowed_cpu_memory_node_list( + def get_allowed_cpu_core_node_list( cls) -> tuple[list[int], list[LogicalCPUInfo]]: assert platform.system() == "Linux" @@ -332,5 +332,10 @@ class CpuPlatform(Platform): supplied model configuration. """ arch = cls.get_cpu_architecture() - return (cls.supports_v1(model_config) and arch - in (CpuArchEnum.X86, CpuArchEnum.POWERPC, CpuArchEnum.ARM)) + return (cls.supports_v1(model_config) + and arch in (CpuArchEnum.X86, CpuArchEnum.POWERPC, + CpuArchEnum.ARM, CpuArchEnum.S390X)) + + @classmethod + def opaque_attention_op(cls) -> bool: + return True diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 321db8287c0f8..5cbb7346436ef 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -350,17 +350,7 @@ class CudaPlatformBase(Platform): return FLEX_ATTENTION_V1 # Backends for V0 engine - if selected_backend == _Backend.FLASHINFER: - logger.info("Using FlashInfer backend.") - if cls.has_device_capability(100): - from vllm.v1.attention.backends.utils import ( - set_kv_cache_layout) - logger.info_once( - "Using HND KV cache layout on V1 engine by default for " - "Blackwell (SM 10.0) GPUs.") - set_kv_cache_layout("HND") - return "vllm.attention.backends.flashinfer.FlashInferBackend" - elif selected_backend == _Backend.XFORMERS: + if selected_backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") return "vllm.attention.backends.xformers.XFormersBackend" elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN: @@ -416,10 +406,6 @@ class CudaPlatformBase(Platform): if (fp8_kv_cache and not flash_attn_supports_fp8()): logger.info( "Cannot use FlashAttention backend for FP8 KV 
cache.") - logger.warning( - "Please use FlashInfer backend with FP8 KV Cache for " - "better performance by setting environment variable " - "VLLM_ATTENTION_BACKEND=FLASHINFER") target_backend = _Backend.XFORMERS except ImportError: logger.info( @@ -456,6 +442,10 @@ class CudaPlatformBase(Platform): def use_custom_allreduce(cls) -> bool: return True + @classmethod + def opaque_attention_op(cls) -> bool: + return True + @classmethod def get_static_graph_wrapper_cls(cls) -> str: return "vllm.compilation.cuda_graph.CUDAGraphWrapper" @@ -495,18 +485,63 @@ class CudaPlatformBase(Platform): return cuda_device_count_stateless() @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: fp8_attention = kv_cache_dtype.startswith("fp8") - will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND") - ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" + attention_backend = envs.VLLM_ATTENTION_BACKEND + supported = False - if cls.is_device_capability(100): - supported = True - elif fp8_attention and will_use_fa: - from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 - supported = flash_attn_supports_fp8() + if model_config is not None and model_config.use_mla: + # Default to CutlassMLA for blackwell, + # FlashMLA otherwise + if attention_backend is None: + if cls.is_device_capability(100): + attention_backend = "CUTLASS_MLA" + else: + attention_backend = "FLASHMLA" + + # Only FlashMLA supports fp8 + if attention_backend == "FLASHMLA": + supported = True + else: + supported = (not fp8_attention) + else: + # Default to FlashAttention + if attention_backend is None: + attention_backend = "FLASH_ATTN_VLLM_V1" + + # All Blackwell backends support fp8 + if cls.is_device_capability(100): + supported = True + elif attention_backend == "FLASH_ATTN_VLLM_V1": + if fp8_attention: + from vllm.attention.utils.fa_utils import ( + flash_attn_supports_fp8) 
+ supported = flash_attn_supports_fp8() + else: + supported = True return supported + @classmethod + def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not cls.has_device_capability(80): + capability = cls.get_device_capability() + gpu_name = cls.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + + raise ValueError( + "Bfloat16 is only supported on GPUs " + "with compute capability of at least 8.0. " + f"Your {gpu_name} GPU {compute_str}. " + "You can use float16 instead by explicitly setting the " + "`dtype` flag in CLI, for example: --dtype=half.") + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 4017f1ca7eecb..01f3e2d977bc3 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -81,6 +81,7 @@ class CpuArchEnum(enum.Enum): X86 = enum.auto() ARM = enum.auto() POWERPC = enum.auto() + S390X = enum.auto() OTHER = enum.auto() UNKNOWN = enum.auto() @@ -377,6 +378,8 @@ class Platform: return CpuArchEnum.ARM elif machine.startswith("ppc"): return CpuArchEnum.POWERPC + elif machine == "s390x": + return CpuArchEnum.S390X return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN @@ -506,6 +509,14 @@ class Platform: """ return False + @classmethod + def opaque_attention_op(cls) -> bool: + """ + Returns True if we register attention as one giant opaque custom op + on the current platform + """ + return False + @classmethod def validate_request( cls, @@ -562,12 +573,20 @@ class Platform: raise RuntimeError(f"Unsupported torch distributed backend: {backend}") @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> 
bool: """ Returns if the kv_cache_dtype is supported by the current platform. """ return False + @classmethod + def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + """ + Check if the dtype is supported by the current platform. + """ + raise NotImplementedError + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 3ede86e158554..c6d14aa87c7f2 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -171,7 +171,7 @@ class RocmPlatform(Platform): supported_quantization: list[str] = [ "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", - "quark", "ptpc_fp8", "mxfp4" + "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4" ] @classmethod @@ -411,6 +411,10 @@ class RocmPlatform(Platform): supported_archs = ['gfx94', 'gfx95'] return any(gfx in gcn_arch for gfx in supported_archs) + @classmethod + def opaque_attention_op(cls) -> bool: + return True + @classmethod def get_cu_count(cls, device_id: int = 0) -> int: return torch.cuda.get_device_properties( @@ -459,5 +463,26 @@ class RocmPlatform(Platform): return cuda_device_count_stateless() @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: return True + + @classmethod + def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not cls.has_device_capability(80): + capability = cls.get_device_capability() + gpu_name = cls.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + + raise ValueError( + "Bfloat16 is only supported on GPUs " + "with compute capability of at least 8.0. " + f"Your {gpu_name} GPU {compute_str}. 
" + "You can use float16 instead by explicitly setting the " + "`dtype` flag in CLI, for example: --dtype=half.") diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index ba06abd07f085..d7468d74b021f 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -24,6 +24,8 @@ else: logger = init_logger(__name__) +USE_TPU_COMMONS = False + class TpuPlatform(Platform): _enum = PlatformEnum.TPU @@ -194,13 +196,15 @@ class TpuPlatform(Platform): raise ValueError("Torch XLA does not support per-request seed.") @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: return True try: from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform TpuPlatform = TpuCommonsPlatform # type: ignore + USE_TPU_COMMONS = True except ImportError: logger.info("tpu_commons not found, using vLLM's TpuPlatform") pass diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 66ebc8ad9d22f..84f4cd7256465 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Optional import torch import vllm.envs as envs -from vllm.config import CUDAGraphMode from vllm.logger import init_logger from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS @@ -91,26 +90,14 @@ class XPUPlatform(Platform): if cache_config and cache_config.block_size is None: cache_config.block_size = 64 - # FIXME: Temporarily forcing eager mode - # remove after t.compile support stabilizes. - if (envs.VLLM_USE_V1 and model_config is not None - and not vllm_config.model_config.enforce_eager): - from vllm.config import CompilationLevel - vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION # noqa: E501 - - # Instances created using VllmConfig() typically have model_config as - # None by default. The modification involves adding a check to prevent - # potential null exceptions check and update model config. 
- if model_config is not None and model_config.dtype == torch.bfloat16 \ - and not cls.device_support_bf16(): - model_config.dtype = torch.float16 - + # lazy import to avoid circular import + from vllm.config import CUDAGraphMode compilation_config = vllm_config.compilation_config if compilation_config.cudagraph_mode is None or \ compilation_config.cudagraph_mode.max_cudagraph_mode() \ != CUDAGraphMode.NONE: - logger.info("[XPU] CUDA graph is not supported on XPU, " - "disabling cudagraphs.") + logger.info("[XPU] CUDA graph is not supported on XPU, disabling " + "cudagraphs. Fallback to cudagraph_mode=NONE") compilation_config.cudagraph_mode = CUDAGraphMode.NONE # check and update parallel config @@ -161,30 +148,11 @@ class XPUPlatform(Platform): torch.xpu.reset_peak_memory_stats(device) return torch.xpu.max_memory_allocated(device) - @classmethod - def device_support_bf16(cls) -> bool: - device_name = cls.get_device_name().lower() - if cls.is_client_gpu_a770(): - logger.warning("Intel Arc A770 have bfloat16 accuracy known issue," - " fallback to float16") - return False - else: - logger.info( - "Device name %s supports bfloat16. 
Please file an issue " - "if you encounter any accuracy problems with bfloat16.", - device_name) - return True - @classmethod def is_data_center_gpu(cls) -> bool: device_name = cls.get_device_name().lower() return device_name.count("data center gpu") > 0 - @classmethod - def is_client_gpu_a770(cls) -> bool: - device_name = cls.get_device_name().lower() - return device_name.count("a770") > 0 - @classmethod def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa @@ -196,3 +164,18 @@ class XPUPlatform(Platform): @classmethod def device_count(cls) -> int: return torch.xpu.device_count() + + @classmethod + def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + if torch_dtype == torch.bfloat16: # noqa: SIM102 + device_name = cls.get_device_name().lower() + # client gpu a770 + if device_name.count("a770") > 0: + raise ValueError( + "Intel Arc A770 have bfloat16 accuracy known issue. " + "You can use float16 instead by explicitly setting the " + "`dtype` flag in CLI, for example: --dtype=half.") + + @classmethod + def opaque_attention_op(cls) -> bool: + return True diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py index 4f4522d726e89..df9e84163f16c 100644 --- a/vllm/reasoning/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -44,7 +44,7 @@ class ReasoningParser: return self.model_tokenizer.get_vocab() @abstractmethod - def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + def is_reasoning_end(self, input_ids: list[int]) -> bool: """ Check if the reasoning content ends in the input_ids. 
diff --git a/vllm/sequence.py b/vllm/sequence.py index cbe63f8d1d4e4..36b1b198bd5a5 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -16,14 +16,17 @@ import msgspec import torch from vllm.inputs import SingletonInputs -from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict from vllm.pooling_params import PoolingParams from vllm.sampling_params import RequestOutputKind, SamplingParams if TYPE_CHECKING: + from vllm.lora.request import LoRARequest from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorOutput) +else: + LoRARequest = Any + KVConnectorOutput = Any VLLM_TOKEN_ID_ARRAY_TYPE = "l" @@ -144,18 +147,7 @@ class SequenceDataDelta( class SequenceData(msgspec.Struct, omit_defaults=True): # type: ignore[call-arg] - """Data associated with a sequence. - - Args: - prompt_token_ids: The token IDs of the prompt. - output_token_ids: The token IDs of the output. Set to an empty list if - None. - - Attributes: - prompt_token_ids: The token IDs of the prompt. - output_token_ids: The token IDs of the output. - cumulative_logprob: The cumulative log probability of the output. - """ + """Data associated with a sequence.""" # NOTE: we cannot use Union[list, array] because msgspec cannot support # union of 2 list types. 
_prompt_token_ids: array @@ -253,10 +245,12 @@ class SequenceData(msgspec.Struct, @property def cumulative_logprob(self) -> float: + """The cumulative log probability of the output.""" return self._cumulative_logprob @property def prompt_token_ids(self) -> tuple[int, ...]: + """The token IDs of the prompt.""" return self._prompt_token_ids_tuple @prompt_token_ids.setter @@ -274,6 +268,7 @@ class SequenceData(msgspec.Struct, @property def output_token_ids(self) -> tuple[int, ...]: + """The token IDs of the output.""" return tuple(self._output_token_ids) @output_token_ids.setter @@ -522,9 +517,9 @@ class Sequence: @property def multi_modal_data(self) -> MultiModalKwargs: if self.inputs["type"] == "multimodal": - return self.inputs["mm_kwargs"] + return self.inputs["mm_kwargs"].get_data() - return MultiModalKwargs({}) + return MultiModalKwargs() @property def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: @@ -780,7 +775,7 @@ class SequenceGroup: return self.first_seq.multi_modal_data elif self.encoder_seq is not None: return self.encoder_seq.multi_modal_data - return MultiModalKwargs({}) + return MultiModalKwargs() @property def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: @@ -937,7 +932,7 @@ class SequenceGroupMetadata( omit_defaults=True): # type: ignore[call-arg] """Metadata for a sequence group. Used to create `AttentionMetadata`. - Args: + Attributes: request_id: The ID of the request. is_prompt: Whether the request is at prompt stage. seq_data: The sequence data. (Seq id -> sequence data) @@ -947,14 +942,14 @@ class SequenceGroupMetadata( do_sample: True if sampling is required. Sampling is not required when e.g., prefill is chunked, and the current iteration only computes query tokens for prefill, we don't need sampling. - token_chunk_size: The number of tokens to be processed (per sequence). - None if chunking is not required. + pooling_params: Pooling parameters. lora_request: LoRA request. 
computed_block_nums: The block numbers that are already computed, used in prefix caching. state: Internal state tied to this sequence group. + token_type_ids: Token type IDs. multi_modal_data: Multi modal data. - mm_processor_kwargs: Multimodal input processor / mapper overrides. + multi_modal_placeholders: Multi modal placeholders. encoder_seq_data: Optional sequence data for encoder prompt (SequenceGroup.encoder_seq). Should be None unless you are working with an encoder/decoder @@ -1040,12 +1035,13 @@ class SequenceOutput( array_like=True): # type: ignore[call-arg] """The model output associated with a sequence. - Args: + Attributes: parent_seq_id: The ID of the parent sequence (for forking in beam search). output_token: The output token ID. logprobs: The logprobs of the output token. (Token id -> logP(x_i+1 | x_0, ..., x_i)) + output_embed: Optional output embedding tensor. """ parent_seq_id: int output_token: int @@ -1138,7 +1134,7 @@ class IntermediateTensors: """ tensors: dict[str, torch.Tensor] - kv_connector_output: Optional["KVConnectorOutput"] + kv_connector_output: Optional[KVConnectorOutput] def __init__(self, tensors): # manually define this function, so that @@ -1163,7 +1159,13 @@ class IntermediateTensors: return len(self.tensors) def __eq__(self, other: object): - return isinstance(other, self.__class__) and self + if not isinstance(other, self.__class__): + return False + if self.tensors.keys() != other.tensors.keys(): + return False + return all( + torch.equal(self.tensors[k], other.tensors[k]) + for k in self.tensors) def __repr__(self) -> str: return f"IntermediateTensors(tensors={self.tensors})" diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py index e0ef7f0999d47..d09c5fa924fb0 100644 --- a/vllm/transformers_utils/chat_templates/registry.py +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -20,6 +20,16 @@ def _get_qwen_chat_template_fallback( return 
CHAT_TEMPLATES_DIR / "template_basic.jinja" +def _get_minicpmv_chat_template_fallback( + tokenizer_name_or_path: str) -> Optional[Path]: + # MiniCPM-V-4.5 version uses a dedicated template + if "4.5" in tokenizer_name_or_path or "4_5" in tokenizer_name_or_path: + return CHAT_TEMPLATES_DIR / "template_minicpmv45.jinja" + + # Other versions use chatml template + return CHAT_TEMPLATES_DIR / "template_chatml.jinja" + + # yapf: disable _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja", @@ -27,6 +37,7 @@ _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", "florence2": CHAT_TEMPLATES_DIR / "template_basic.jinja", "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja", + "minicpmv": _get_minicpmv_chat_template_fallback, "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja", "qwen": _get_qwen_chat_template_fallback, } diff --git a/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja b/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja new file mode 100644 index 0000000000000..661ebd1cf5c17 --- /dev/null +++ b/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja @@ -0,0 +1,93 @@ +{%- set enable_thinking = enable_thinking | default(false) %} +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }} +{%- 
else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} + +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} + +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set content = message.content %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is defined and message.reasoning_content is not none %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '</think>' in message.content %} + {%- set content = message.content.split('</think>')[-1].lstrip('\n') %} + {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and reasoning_content) %} + {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function 
%} + {%- endif %} + {{- '<tool_call>\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n</tool_call>' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n<tool_response>\n' }} + {{- message.content }} + {{- '\n</tool_response>' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} + +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '<think>\n\n</think>\n\n' }} + {%- endif %} + {%- if enable_thinking is defined and enable_thinking is true %} + {{- '<think>\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index d8c964fb2a4a4..bec792465bfbb 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -14,7 +14,7 @@ from huggingface_hub import get_safetensors_metadata, hf_hub_download from huggingface_hub import list_repo_files as hf_list_repo_files from huggingface_hub import try_to_load_from_cache from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, - HFValidationError, LocalEntryNotFoundError, + LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError) from transformers import GenerationConfig, PretrainedConfig @@ -27,19 +27,6 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME from vllm import envs from vllm.logger import init_logger -# yapf conflicts with isort for this block -# yapf: disable -from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config, - EAGLEConfig, JAISConfig, 
- KimiVLConfig, MedusaConfig, - MLPSpeculatorConfig, - Nemotron_Nano_VL_Config, - NemotronConfig, OvisConfig, - RWConfig, SpeculatorsConfig, - Step3TextConfig, Step3VLConfig, - UltravoxConfig) -# yapf: enable -from vllm.transformers_utils.configs.mistral import adapt_config_dict from vllm.transformers_utils.utils import check_gguf_file if envs.VLLM_USE_MODELSCOPE: @@ -67,24 +54,31 @@ def _get_hf_token() -> Optional[str]: return None -_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = { - "chatglm": ChatGLMConfig, - "deepseek_vl_v2": DeepseekVLV2Config, - "kimi_vl": KimiVLConfig, - "Llama_Nemotron_Nano_VL": Nemotron_Nano_VL_Config, - "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) - "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) - "jais": JAISConfig, - "mlp_speculator": MLPSpeculatorConfig, - "medusa": MedusaConfig, - "eagle": EAGLEConfig, - "speculators": SpeculatorsConfig, - "nemotron": NemotronConfig, - "ovis": OvisConfig, - "ultravox": UltravoxConfig, - "step3_vl": Step3VLConfig, - "step3_text": Step3TextConfig, -} +class LazyConfigDict(dict): + + def __getitem__(self, key): + import vllm.transformers_utils.configs as configs + return getattr(configs, super().__getitem__(key)) + + +_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( + chatglm="ChatGLMConfig", + deepseek_vl_v2="DeepseekVLV2Config", + kimi_vl="KimiVLConfig", + Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config", + RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct) + RefinedWebModel="RWConfig", # For tiiuae/falcon-7b(-instruct) + jais="JAISConfig", + mlp_speculator="MLPSpeculatorConfig", + medusa="MedusaConfig", + eagle="EAGLEConfig", + speculators="SpeculatorsConfig", + nemotron="NemotronConfig", + ovis="OvisConfig", + ultravox="UltravoxConfig", + step3_vl="Step3VLConfig", + step3_text="Step3TextConfig", +) _CONFIG_ATTRS_MAPPING: dict[str, str] = { "llm_config": "text_config", @@ -335,6 +329,7 @@ def maybe_override_with_speculators_target_model( 
gguf_model_repo = Path(model).parent else: gguf_model_repo = None + kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE config_dict, _ = PretrainedConfig.get_config_dict( model if gguf_model_repo is None else gguf_model_repo, revision=revision, @@ -400,6 +395,7 @@ def get_config( raise ValueError(error_message) from e if config_format == ConfigFormat.HF: + kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE config_dict, _ = PretrainedConfig.get_config_dict( model, revision=revision, @@ -459,6 +455,8 @@ def get_config( model, revision, **kwargs) config_dict["max_position_embeddings"] = max_position_embeddings + from vllm.transformers_utils.configs.mistral import adapt_config_dict + config = adapt_config_dict(config_dict) # Mistral configs may define sliding_window as list[int]. Convert it @@ -503,6 +501,24 @@ def get_config( if quantization_config is not None: config.quantization_config = quantization_config + # auto-enable DeepGEMM UE8M0 on Hopper if model config requests it + scale_fmt = quantization_config.get("scale_fmt", None) + if scale_fmt in ("ue8m0", ): + if not envs.is_set("VLLM_USE_DEEP_GEMM_E8M0_HOPPER"): + os.environ["VLLM_USE_DEEP_GEMM_E8M0_HOPPER"] = "1" + logger.info_once( + ("Detected quantization_config.scale_fmt=%s; " + "enabling Hopper UE8M0."), + scale_fmt, + ) + elif not envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER: + logger.warning_once( + ("Model config requests UE8M0 " + "(quantization_config.scale_fmt=%s), but " + "VLLM_USE_DEEP_GEMM_E8M0_HOPPER=0 is set; " + "Hopper UE8M0 disabled."), + scale_fmt, + ) if hf_overrides_kw: logger.debug("Overriding HF config with %s", hf_overrides_kw) @@ -532,7 +548,7 @@ def try_get_local_file(model: Union[str, Path], revision=revision) if isinstance(cached_filepath, str): return Path(cached_filepath) - except HFValidationError: + except ValueError: ... 
return None @@ -908,3 +924,42 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int: exc_info=e) return max_position_embeddings + + +def get_model_path(model: Union[str, Path], revision: Optional[str] = None): + if os.path.exists(model): + return model + assert huggingface_hub.constants.HF_HUB_OFFLINE + common_kwargs = { + "local_files_only": huggingface_hub.constants.HF_HUB_OFFLINE, + "revision": revision, + } + + if envs.VLLM_USE_MODELSCOPE: + from modelscope.hub.snapshot_download import snapshot_download + return snapshot_download(model_id=model, **common_kwargs) + + from huggingface_hub import snapshot_download + return snapshot_download(repo_id=model, **common_kwargs) + + +def get_hf_file_bytes(file_name: str, + model: Union[str, Path], + revision: Optional[str] = 'main') -> Optional[bytes]: + """Get file contents from HuggingFace repository as bytes.""" + file_path = try_get_local_file(model=model, + file_name=file_name, + revision=revision) + + if file_path is None: + hf_hub_file = hf_hub_download(model, + file_name, + revision=revision, + token=_get_hf_token()) + file_path = Path(hf_hub_file) + + if file_path is not None and file_path.is_file(): + with open(file_path, 'rb') as file: + return file.read() + + return None diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index bc249c5836034..6aabf9e5262e6 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -61,8 +61,8 @@ class EAGLEConfig(PretrainedConfig): else f"Eagle3{arch}" for arch in self.model.architectures ] else: - raise ValueError(f"Invalid method {method}. \ - Supported methods are eagle and eagle3.") + raise ValueError(f"Invalid method {method}. 
" + "Supported methods are eagle and eagle3.") super().__init__(**kwargs) diff --git a/vllm/transformers_utils/detokenizer_utils.py b/vllm/transformers_utils/detokenizer_utils.py index be1040c3e0147..101f31d39cc1f 100644 --- a/vllm/transformers_utils/detokenizer_utils.py +++ b/vllm/transformers_utils/detokenizer_utils.py @@ -23,27 +23,32 @@ def _convert_tokens_to_string_with_added_encoders( # NOTE(woosuk): The following code is slow because it runs a for loop over # the output_tokens. In Python, running a for loop over a list can be slow # even when the loop body is very simple. + # Performance improvements: avoid repeated attribute and function lookups; + # localize frequently used objects; + sub_texts: list[str] = [] current_sub_text: list[str] = [] - all_special_tokens = set(tokenizer.all_special_tokens) + convert_tokens_to_string = tokenizer.convert_tokens_to_string + added_vocab_set = set(tokenizer.get_added_vocab()) + all_special_tokens = set( + tokenizer.all_special_tokens) if skip_special_tokens else () + for token in output_tokens: - if skip_special_tokens and token in all_special_tokens: + # Use precomputed set for skip-special check + if token in all_special_tokens: continue - if token in tokenizer.get_added_vocab(): + if token in added_vocab_set: if current_sub_text: - sub_text = tokenizer.convert_tokens_to_string(current_sub_text) - sub_texts.append(sub_text) - current_sub_text = [] + sub_texts.append(convert_tokens_to_string(current_sub_text)) + current_sub_text.clear() sub_texts.append(token) else: current_sub_text.append(token) if current_sub_text: - sub_text = tokenizer.convert_tokens_to_string(current_sub_text) - sub_texts.append(sub_text) + sub_texts.append(convert_tokens_to_string(current_sub_text)) if spaces_between_special_tokens: return " ".join(sub_texts) - else: - return "".join(sub_texts) + return "".join(sub_texts) # 5 is an arbitrary value that should work for all diff --git a/vllm/transformers_utils/processors/__init__.py 
b/vllm/transformers_utils/processors/__init__.py index eca4d7c884dd3..8a1ad226d99f0 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -11,5 +11,6 @@ reasons: from vllm.transformers_utils.processors.deepseek_vl2 import ( DeepseekVLV2Processor) from vllm.transformers_utils.processors.ovis import OvisProcessor +from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor -__all__ = ["DeepseekVLV2Processor", "OvisProcessor"] +__all__ = ["DeepseekVLV2Processor", "OvisProcessor", "Ovis2_5Processor"] diff --git a/vllm/transformers_utils/processors/ovis2_5.py b/vllm/transformers_utils/processors/ovis2_5.py new file mode 100644 index 0000000000000..d3273257ff8c2 --- /dev/null +++ b/vllm/transformers_utils/processors/ovis2_5.py @@ -0,0 +1,458 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from functools import cached_property +from typing import Optional, Union + +import numpy as np +import PIL +import torch +from transformers import AutoProcessor, BatchFeature +from transformers.image_utils import ImageInput +from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin, + Unpack) +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +__all__ = ['Ovis2_5Processor'] +IMAGE_TOKEN = "<image>" +VIDEO_TOKEN = "<video>" +MIN_PIXELS = 448 * 448 +MAX_PIXELS = 1792 * 1792 + + +class Ovis2_5ProcessorKwargs(ProcessingKwargs, + total=False): # type: ignore[call-arg] + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": { + 'convert_to_rgb': True, + 'min_pixels': MIN_PIXELS, + 'max_pixels': MAX_PIXELS, + }, + "videos_kwargs": { + 'convert_to_rgb': True, + 'min_pixels': MIN_PIXELS, + 'max_pixels': MAX_PIXELS, + } + } + + +class Ovis2_5Processor(ProcessorMixin): + r""" + Constructs a Ovis processor which wraps a Ovis image processor + and a Qwen2 tokenizer into a single 
processor. + [`OvisProcessor`] offers all the functionalities of + [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. + See the [`~OvisProcessor.__call__`] and [`~OvisProcessor.decode`] + for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will + be used to convert lists of messages in a chat into + a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template", "image_pad_token"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor=None, + tokenizer=None, + chat_template=None, + image_pad_token=None, + patch_size=16, + hidden_stride=2, + temporal_patch_size=1, + **kwargs, + ): + self.image_token = IMAGE_TOKEN + self.video_token = VIDEO_TOKEN + self.image_pad_token = "<|image_pad|>" + + self.patch_size = patch_size + self.hidden_stride = hidden_stride + self.temporal_patch_size = temporal_patch_size + super().__init__(image_processor, + tokenizer, + chat_template=chat_template) + + @cached_property + def extra_special_tokens(self): + image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token] + extra_special_tokens = { + "image_token": -200, + "video_token": -201, + "visual_atom": -300, + "image_start": -301, + "image_end": -302, + "video_start": -303, + "video_end": -304, + 'image_pad': image_pad_token_id, + } + return extra_special_tokens + + def __call__( + self, + images: ImageInput = None, + videos: Union[np.ndarray, list[ImageInput]] = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], + list[PreTokenizedInput]] = None, + **kwargs: Unpack[Ovis2_5ProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) + and image(s). 
This method forwards the `text`and `kwargs` arguments + to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` + is not `None` to encode the text. To prepare the vision inputs, + this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] + if `vision_infos` is not `None`. + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, + `list[PIL.Image.Image]`, `list[np.ndarray]`, + `list[torch.Tensor]`): + The image or batch of images to be prepared. + Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats + are supported. + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. + Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as + list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with + a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, + `list[torch.Tensor]`): + The image or batch of videos to be prepared. Each video + can be a 4D NumPy array or PyTorch tensor, or a nested + list of 3D frames. Both channels-first and channels-last + formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. + Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + - **input_ids** -- list of token ids to be fed to a model. + Returned when `text` is not `None`. 
+ - **attention_mask** -- list of indices specifying which tokens + should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* + is in `self.model_input_names` and if `text` is not `None`). + - **pixel_values** -- Pixel values to be fed to a model. + Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to + a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- list of image 3D grid in LLM. Returned + when `images` is not `None`. + - **video_grid_thw** -- list of video 3D grid in LLM. Returned + when `videos` is not `None`. + - **second_per_grid_ts** -- list of video seconds per time grid. + Returned when `videos` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Ovis2_5ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + # Process all images first + visual_features = {} + output = BatchFeature() + if images is not None: + processed_images = [] + image_placeholders_list = [] + grids = [] + # Process each image + for image in images if isinstance(images, list) else [images]: + pixel_values, image_placeholders, grid = ( + self.preprocess_multidata( + images=image, **output_kwargs["images_kwargs"])) + processed_images.append(pixel_values) + image_placeholders_list.append(image_placeholders) + grids.append(grid) + + # assign all processed images + if processed_images: + visual_features["image_placeholders"] = image_placeholders_list + output["pixel_values"] = processed_images + output["grids"] = grids + + if videos is not None: + processed_videos = [] + videos_placeholders_list = [] + grids = [] + # Process each video + for video in videos if isinstance(videos, list) else [videos]: + pixel_values, video_placeholders, grid = ( + self.preprocess_multidata( + video=video, **output_kwargs["videos_kwargs"])) + processed_videos.append(pixel_values) + videos_placeholders_list.append(video_placeholders) + grids.append(grid) + # 
assign all processed videos + if processed_videos: + visual_features[ + "video_placeholders"] = videos_placeholders_list + output["video_pixel_values"] = processed_videos + output["video_grids"] = grids + + # Process text input + if text is not None: + if not isinstance(text, list): + text = [text] + tokenized_batched_text = self._tokenize_with_visual_symbol(text) + image_token_id = self.get_token_value("image_token") + video_token_id = self.get_token_value("video_token") + replaced_ids_list = [] + image_idx = 0 + video_idx = 0 + for ids_tensor in tokenized_batched_text: + has_image_tokens = (image_token_id in ids_tensor + and "image_placeholders" in visual_features + and image_idx < len( + visual_features["image_placeholders"])) + has_video_tokens = (video_token_id in ids_tensor + and "video_placeholders" in visual_features + and video_idx < len( + visual_features["video_placeholders"])) + if has_image_tokens or has_video_tokens: + # Convert to list for easier manipulation + ids_list = ids_tensor.tolist() + new_ids = [] + + # Replace placeholders + for token_id in ids_list: + if token_id == image_token_id: + new_ids.extend( + visual_features["image_placeholders"] + [image_idx]) + image_idx += 1 + elif token_id == video_token_id: + new_ids.extend( + visual_features["video_placeholders"] + [video_idx]) + video_idx += 1 + else: + new_ids.append(token_id) + # Convert back to tensor + ids_tensor = torch.tensor(new_ids, dtype=torch.long) + replaced_ids_list.append(ids_tensor) + if replaced_ids_list: + replaced_and_tokenized_ids = torch.stack(replaced_ids_list) + else: + replaced_and_tokenized_ids = torch.tensor([], dtype=torch.long) + output["input_ids"] = replaced_and_tokenized_ids + + return output + # If only images were provided + return BatchFeature(data=visual_features) + + def _tokenize_with_visual_symbol(self, + text_list: list[str]) -> torch.LongTensor: + batch_token_ids = [] + for text in text_list: + token_ids = [] + video_token_id = 
self.get_token_value("video_token") + image_token_id = self.get_token_value("image_token") + video_split_texts = text.split(self.video_token) + + for j, video_segment in enumerate(video_split_texts): + image_split_texts = video_segment.split(self.image_token) + text_chunks = [ + self.tokenizer(chunk, add_special_tokens=False).input_ids + for chunk in image_split_texts + ] + segment_tokens = [] + for i, chunk in enumerate(text_chunks): + segment_tokens.extend(chunk) + if i < len(text_chunks) - 1: + segment_tokens.append(image_token_id) + token_ids.extend(segment_tokens) + if j < len(video_split_texts) - 1: + token_ids.append(video_token_id) + + batch_token_ids.append(token_ids) + return torch.tensor(batch_token_ids, dtype=torch.long) + + # Copied from qwen2_vl + def smart_resize(self, + height: int, + width: int, + factor: int = 28, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS): + """Rescales the image so that the following conditions are met: + 1. Both dimensions (height and width) are divisible by 'factor'. + 2. The total number of pixels is within the range + ['min_pixels', 'max_pixels']. + 3. The aspect ratio of the image is maintained as closely as possible. 
+ """ + if height < factor or width < factor: + print(f"height:{height} or width:{width} must be " + f"larger than factor:{factor}") + if height < width: + width = round(factor / height * width) + height = factor + else: + height = round(factor / width * height) + width = factor + + elif max(height, width) / min(height, width) > 200: + print(f"absolute aspect ratio must be smaller than 200, " + f"got {max(height, width) / min(height, width)}") + if height > width: + height = 200 * width + else: + width = 200 * height + + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + def get_token_value(self, tok): + return self.extra_special_tokens[tok] + + def construct_visual_indicators(self, grid, is_video: bool = False): + if is_video: + start_token = self.get_token_value('video_start') + end_token = self.get_token_value('video_end') + else: + start_token = self.get_token_value('image_start') + end_token = self.get_token_value('image_end') + + image_placeholders = [start_token, self.get_token_value('visual_atom')] + if grid[0] * grid[1] > 1: + for r in range(grid[0]): + for c in range(grid[1]): + image_placeholders.append( + self.get_token_value('visual_atom')) + + image_placeholders.append(end_token) + return image_placeholders + + def construct_visual_placeholders(self, grid, is_video: bool = False): + visual_placeholders = self.construct_visual_indicators((1, 1), + is_video) + + image_atom_token_id = self.get_token_value('visual_atom') + # Extract the padding token ID from tokenizer + image_padding_token_id = self.get_token_value('image_pad') 
+ + num_image_atoms = grid[0] * grid[1] * grid[2] + num_image_atoms //= self.hidden_stride**2 + num_image_atoms //= self.temporal_patch_size + + # Create a new list with padding tokens inserted + padded_placeholder_tokens = [] + for token in visual_placeholders: + if token == image_atom_token_id: + padded_placeholder_tokens.extend([image_padding_token_id] * + num_image_atoms) + else: + padded_placeholder_tokens.append(image_padding_token_id) + return padded_placeholder_tokens + + def preprocess_multidata( + self, + images: Optional[Union[PIL.Image.Image, list[PIL.Image.Image]]] = None, + video: Optional[Union[list[PIL.Image.Image], np.ndarray]] = None, + convert_to_rgb: Optional[bool] = True, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS, + return_tensors: Optional[str] = 'pt', + ): + is_video = False + if images is not None: + if not isinstance(images, list): + images = [images] + elif video is not None: + is_video = True + # type of vidoe in dummy_mm_data is np.ndarray + if isinstance(video, np.ndarray): + images = [] + for i in range(video.shape[0]): + image = PIL.Image.fromarray(video[i].astype(np.uint8)) + images.append(image) + elif isinstance(video, list): + images = video + min_pixels = min(max_pixels if max_pixels is not None else MAX_PIXELS, + min_pixels if min_pixels is not None else MIN_PIXELS) + images = [ + image.convert("RGB") + if convert_to_rgb and image.mode != 'RGB' else image + for image in images + ] + + width, height = images[0].size + resized_height, resized_width = height, width + processed_images = [] + for image in images: + resized_height, resized_width = self.smart_resize( + height, + width, + factor=self.patch_size * self.hidden_stride, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + new_size = dict(height=resized_height, width=resized_width) + image_pt = self.image_processor.preprocess( + image, size=new_size, return_tensors="np")['pixel_values'][0] + + processed_images.append(image_pt) + + patches = 
np.array(processed_images) + if patches.shape[0] % self.temporal_patch_size != 0: + num_to_pad = self.temporal_patch_size - (patches.shape[0] % + self.temporal_patch_size) + repeats = np.repeat(patches[-1][np.newaxis], num_to_pad, axis=0) + patches = np.concatenate([patches, repeats], axis=0) + channel = patches.shape[1] + grid_t = patches.shape[0] // self.temporal_patch_size + grid_h = resized_height // self.patch_size + grid_w = resized_width // self.patch_size + + patches = patches.reshape( + grid_t, + self.temporal_patch_size, + channel, + grid_h // self.hidden_stride, + self.hidden_stride, + self.patch_size, + grid_w // self.hidden_stride, + self.hidden_stride, + self.patch_size, + ) + patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8) + flatten_patches = patches.reshape( + grid_t * grid_h * grid_w, channel * self.temporal_patch_size * + self.patch_size * self.patch_size) + + visual_placeholders = self.construct_visual_placeholders( + [grid_t, grid_h, grid_w], is_video) + return torch.tensor( + flatten_patches), visual_placeholders, torch.tensor( + [[grid_t, grid_h, grid_w]]) + + +AutoProcessor.register("Ovis2_5Processor", Ovis2_5Processor) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index d2be2ceeeae6d..b3f1977f26cf4 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -7,7 +7,6 @@ import os import warnings from functools import lru_cache from pathlib import Path -from types import MethodType from typing import TYPE_CHECKING, Any, Optional, Union import huggingface_hub @@ -50,12 +49,11 @@ def decode_tokens( `skip_special_tokens=None` means to use the backend's default settings. 
""" - decode_method = getattr(tokenizer, "_decode", tokenizer.decode) if skip_special_tokens is not None: - return decode_method(token_ids, - skip_special_tokens=skip_special_tokens) + return tokenizer.decode(token_ids, + skip_special_tokens=skip_special_tokens) - return decode_method(token_ids) + return tokenizer.decode(token_ids) def encode_tokens( @@ -144,26 +142,6 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: return cached_tokenizer -def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None: - """Patch _pad method to accept `padding_side` for older tokenizers.""" - orig_pad = tokenizer._pad - - def _pad( - self: PreTrainedTokenizer, - *args, - padding_side: Optional[str] = None, - **kwargs, - ): - if padding_side is not None and padding_side != self.padding_side: - msg = ("`padding_side` argument is not supported by " - f"{type(tokenizer).__name__} and will be ignored.") - warnings.warn(msg, stacklevel=2) - - return orig_pad(*args, **kwargs) - - tokenizer._pad = MethodType(_pad, tokenizer) - - def get_tokenizer( tokenizer_name: Union[str, Path], *args, @@ -271,12 +249,6 @@ def get_tokenizer( } tokenizer.add_special_tokens(special_tokens_map) - # NOTE: We can remove this after https://github.com/zai-org/ChatGLM3/issues/1324 - if type(tokenizer).__name__ in ("ChatGLMTokenizer", - "ChatGLM4Tokenizer"): - assert isinstance(tokenizer, PreTrainedTokenizer) - patch_padding_side(tokenizer) - if not isinstance(tokenizer, PreTrainedTokenizerFast): logger.warning( "Using a slow tokenizer. This might cause a significant " diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index a1f8ad164762d..60bddc5b500b5 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -516,8 +516,8 @@ def random_uuid() -> str: class AsyncMicrobatchTokenizer: """Asynchronous tokenizer with micro-batching. - Pulls pending encode/decode requests from a queue and batches them - up to reduce overhead. 
A single-thread ThreadPoolExecutor is used + Pulls pending encode/decode requests from a queue and batches them + up to reduce overhead. A single-thread ThreadPoolExecutor is used so the event loop stays responsive. """ @@ -664,18 +664,18 @@ class AsyncMicrobatchTokenizer: def _queue_key(self, op: str, kwargs: dict) -> tuple: """ Return a normalized key describing operation + kwargs. - + - `add_special_tokens`: {True/False} - `truncation`: {True/False} - - If `truncation` is False (`max_length` is None), + - If `truncation` is False (`max_length` is None), returns a key for a can_batch queue. - If `truncation` is True and `max_length` is None or equals `tokenizer.model_max_length`, returns a key for a can_batch queue. - Otherwise, returns a key for a cannot_batch queue. - + Examples: - Decode: ("decode",) - - Encode typical: + - Encode typical: ("encode", add_special_tokens, bool_truncation, max_length_label) - Fallback: ("encode", "other") """ @@ -940,6 +940,14 @@ def get_open_port() -> int: return _get_open_port() +def get_open_ports_list(count: int = 5) -> list[int]: + """Get a list of open ports.""" + ports = set() + while len(ports) < count: + ports.add(get_open_port()) + return list(ports) + + def _get_open_port() -> int: port = envs.VLLM_PORT if port is not None: @@ -1315,6 +1323,11 @@ def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): ) +def as_list(maybe_list: Iterable[T]) -> list[T]: + """Convert iterable to list, unless it's already a list.""" + return maybe_list if isinstance(maybe_list, list) else list(maybe_list) + + # `collections` helpers def is_list_of( value: object, @@ -1427,6 +1440,12 @@ def _patched_set_stream(stream: torch.cuda.Stream) -> None: torch.cuda.set_stream = _patched_set_stream +class _StreamPlaceholder: + + def __init__(self): + self.synchronize = lambda: None + + def current_stream() -> torch.cuda.Stream: """ replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`. 
@@ -1446,8 +1465,18 @@ def current_stream() -> torch.cuda.Stream: # On ROCm using the default 0 stream in combination with RCCL # is hurting performance. Therefore creating a dedicated stream # per process - _current_stream_tls.value = torch.cuda.Stream( - ) if current_platform.is_rocm() else torch.cuda.current_stream() + if current_platform.is_rocm(): + _current_stream_tls.value = torch.cuda.Stream() + elif current_platform.is_cpu(): + _current_stream_tls.value = _StreamPlaceholder() + else: + current_stream = current_platform.current_stream + if current_stream is not None: + _current_stream_tls.value = current_stream() + else: + raise ValueError( + "Fail to set current stream, current platform " + "may not support current_stream with torch API") return _current_stream_tls.value @@ -1640,15 +1669,19 @@ def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: return weak_bound -# From: https://stackoverflow.com/a/4104188/2749989 def run_once(f: Callable[P, None]) -> Callable[P, None]: def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: - if not wrapper.has_run: # type: ignore[attr-defined] - wrapper.has_run = True # type: ignore[attr-defined] - return f(*args, **kwargs) + if wrapper.has_run: # type: ignore[attr-defined] + return + + with wrapper.lock: # type: ignore[attr-defined] + if not wrapper.has_run: # type: ignore[attr-defined] + wrapper.has_run = True # type: ignore[attr-defined] + return f(*args, **kwargs) wrapper.has_run = False # type: ignore[attr-defined] + wrapper.lock = threading.Lock() # type: ignore[attr-defined] return wrapper @@ -1941,7 +1974,7 @@ class FlexibleArgumentParser(ArgumentParser): file_path = args[index + 1] - config_args = self._load_config_file(file_path) + config_args = self.load_config_file(file_path) # 0th index is for {serve,chat,complete} # optionally followed by model_tag (only for serve) @@ -1972,7 +2005,7 @@ class FlexibleArgumentParser(ArgumentParser): return args - def _load_config_file(self, file_path: 
str) -> list[str]: + def load_config_file(self, file_path: str) -> list[str]: """Loads a yaml file and returns the key value pairs as a flattened list with argparse like pattern ```yaml @@ -2013,6 +2046,11 @@ class FlexibleArgumentParser(ArgumentParser): if isinstance(value, bool) and key not in store_boolean_arguments: if value: processed_args.append('--' + key) + elif isinstance(value, list): + if value: + processed_args.append('--' + key) + for item in value: + processed_args.append(str(item)) else: processed_args.append('--' + key) processed_args.append(str(value)) @@ -2449,7 +2487,7 @@ class PlaceholderModule(_PlaceholderBase): A placeholder object to use when a module does not exist. This enables more informative errors when trying to access attributes - of a module that does not exists. + of a module that does not exist. """ def __init__(self, name: str) -> None: @@ -2553,7 +2591,7 @@ def direct_register_custom_op( def resolve_obj_by_qualname(qualname: str) -> Any: """ - Resolve an object by its fully qualified name. + Resolve an object by its fully-qualified class name. """ module_name, obj_name = qualname.rsplit(".", 1) module = importlib.import_module(module_name) @@ -3076,7 +3114,7 @@ class LazyLoader(types.ModuleType): """ LazyLoader module borrowed from Tensorflow https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py - with a addition of "module caching". + with an addition of "module caching". Lazily import a module, mainly to avoid pulling in large dependencies. Modules such as `xgrammar` might do additional side effects, so we @@ -3243,6 +3281,24 @@ def sha256_cbor_64bit(input) -> int: return full_hash & ((1 << 64) - 1) +def get_hash_fn_by_name(hash_fn_name: str) -> Callable: + """Get a hash function by name, or raise an error if + the function is not found. + Args: + hash_fn_name: Name of the hash function. + Returns: + A hash function. 
+ """ + if hash_fn_name == "sha256": + return sha256 + if hash_fn_name == "sha256_cbor_64bit": + return sha256_cbor_64bit + if hash_fn_name == "builtin": + return hash + + raise ValueError(f"Unsupported hash function: {hash_fn_name}") + + def is_torch_equal_or_newer(target: str) -> bool: """Check if the installed torch version is >= the target version. diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 861d9c0c0005d..cd1dbfb813fee 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -27,41 +27,37 @@ def is_deep_gemm_supported() -> bool: is_supported_arch = current_platform.is_cuda() and ( current_platform.is_device_capability(90) or current_platform.is_device_capability(100)) - return has_deep_gemm() and is_supported_arch + return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch @functools.cache -def is_blackwell_deep_gemm_e8m0_used() -> bool: +def is_deep_gemm_e8m0_used() -> bool: """Return ``True`` if vLLM is configured to use DeepGEMM " - "E8M0 scale on a Blackwell-class GPU. + "E8M0 scale on a Hopper or Blackwell-class GPU. 
""" - if not (envs.VLLM_USE_DEEP_GEMM): - logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM=0.") - return False - - if not has_deep_gemm(): - logger.debug_once("DeepGEMM E8M0 disabled: DeepGEMM backend missing.") - return False - - if not envs.VLLM_USE_DEEP_GEMM_E8M0: - logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM_E8M0=0.") + if not is_deep_gemm_supported(): + logger.info_once( + "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system.") return False _lazy_init() if _fp8_gemm_nt_impl is None: - logger.debug_once( - "DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found") + logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found") return False - enabled = (current_platform.is_cuda() - and current_platform.has_device_capability(100)) - if enabled: - logger.debug_once("DeepGEMM E8M0 enabled on Blackwell GPU.") - else: - logger.debug_once( - "DeepGEMM E8M0 disabled: not running on Blackwell GPU.") - return enabled + if current_platform.is_device_capability(100) and \ + envs.VLLM_USE_DEEP_GEMM_E8M0: + logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.") + return True + + if current_platform.is_device_capability(90) and \ + envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER: + logger.info_once("DeepGEMM E8M0 enabled on Hopper GPU.") + return True + + logger.info_once("DeepGEMM E8M0 disabled on current configuration.") + return False def _missing(*_: Any, **__: Any) -> NoReturn: @@ -127,20 +123,18 @@ def fp8_gemm_nt(*args, **kwargs): _lazy_init() if _fp8_gemm_nt_impl is None: return _missing(*args, **kwargs) - return _fp8_gemm_nt_impl( - *args, - disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), - **kwargs) + return _fp8_gemm_nt_impl(*args, + disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), + **kwargs) def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): _lazy_init() if _grouped_impl is None: return _missing(*args, **kwargs) - return _grouped_impl( - *args, - disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), - 
**kwargs) + return _grouped_impl(*args, + disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), + **kwargs) def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): @@ -148,9 +142,7 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): if _grouped_masked_impl is None: return _missing(*args, **kwargs) return _grouped_masked_impl( - *args, - disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), - **kwargs) + *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs) def _ceil_to_ue8m0(x: torch.Tensor): @@ -202,12 +194,19 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): return 1 - sim +def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype, + weight: torch.Tensor): + return (is_deep_gemm_supported() and output_dtype == torch.bfloat16 + and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) + + __all__ = [ "calc_diff", "fp8_gemm_nt", "m_grouped_fp8_gemm_nt_contiguous", "fp8_m_grouped_gemm_nt_masked", "per_block_cast_to_fp8", - "is_blackwell_deep_gemm_e8m0_used", + "is_deep_gemm_e8m0_used", "is_deep_gemm_supported", + "should_use_deepgemm_for_fp8_linear", ] diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 0d7d4b694f076..fab134733d4fd 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -132,6 +132,11 @@ def has_nvidia_artifactory() -> bool: This checks connectivity to the kernel inference library artifactory which is required for downloading certain cubin kernels like TRTLLM FHMA. """ + # Since FLASHINFER_CUBIN_DIR defines the pre-downloaded cubins path, when + # it's true, we could assume the cubins are available. 
+ if envs.VLLM_HAS_FLASHINFER_CUBIN: + return True + try: # Use a short timeout to avoid blocking for too long response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5) @@ -148,33 +153,17 @@ def has_nvidia_artifactory() -> bool: return False -def use_trtllm_attention( - num_tokens: int, - max_seq_len: int, - kv_cache_dtype: str, - num_qo_heads: Optional[int], - num_kv_heads: Optional[int], - attn_head_size: Optional[int], - has_sinks: bool = False, -) -> bool: +@functools.cache +def supports_trtllm_attention() -> tuple[bool, Optional[str]]: + """Cache result which only depends on the environment""" + # This is a lambda, call it once + env_value = envs.VLLM_USE_TRTLLM_ATTENTION + # Requires SM100 and NVIDIA artifactory to be accessible to download cubins if not (current_platform.is_device_capability(100) and has_nvidia_artifactory()): - return False + return False, env_value - # Check if the dimensions are supported by TRTLLM decode attention - if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None - or num_qo_heads % num_kv_heads != 0): - return False - - # If sinks are being used, we must use TRTLLM attention as it's - # the only backend that supports them - if has_sinks: - logger.info_once( - "Using TRTLLM attention (required for attention sinks).") - return True - - env_value = envs.VLLM_USE_TRTLLM_ATTENTION if env_value is not None: logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value) # Environment variable is set - respect it @@ -184,8 +173,46 @@ def use_trtllm_attention( use_trtllm = (env_value == "1") if use_trtllm: logger.info_once("Using TRTLLM attention.") - return use_trtllm - else: + return use_trtllm, env_value + + return True, None + + +def use_trtllm_attention( + num_qo_heads: int, + num_kv_heads: int, + num_tokens: int, + max_seq_len: int, + kv_cache_dtype: str, + q_dtype: torch.dtype, + is_prefill: bool, + has_sinks: bool = False, +) -> bool: + use_trtllm, env_value = supports_trtllm_attention() + if not 
use_trtllm: + return False + + if num_qo_heads % num_kv_heads != 0: + return False + + # Must use TRTLLM attention if query is FP8 quantized + if q_dtype == current_platform.fp8_dtype(): + logger.info_once("Using TRTLLM attention (query is quantized).") + return True + + # TRTLLM prefill attention does not support FP8 kv cache with + # non-quantized query + if is_prefill and kv_cache_dtype.startswith("fp8"): + return False + + # If sinks are being used, we must use TRTLLM attention as it's + # the only backend that supports them + if has_sinks: + logger.info_once( + "Using TRTLLM attention (required for attention sinks).") + return True + + if env_value is None: # Environment variable not set - use auto-detection use_trtllm = (num_tokens <= 256 and max_seq_len < 131072 and kv_cache_dtype == "auto") @@ -193,6 +220,9 @@ def use_trtllm_attention( logger.warning_once("Using TRTLLM attention (auto-detected).") return use_trtllm + # Environment variable is set to 1 - respect it + return True + if has_flashinfer(): @@ -235,6 +265,37 @@ if has_flashinfer(): dtype=dtype, device=A.device) + @torch.library.custom_op( + "vllm::bmm_fp8", + mutates_args=[], + device_types="cuda", + ) + def bmm_fp8( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + backend: str, + ) -> torch.Tensor: + from flashinfer import bmm_fp8 as bmm_fp8_ + return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend) + + @torch.library.register_fake("vllm::bmm_fp8", ) + def bmm_fp8_fake( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + backend: str, + ) -> torch.Tensor: + return torch.empty(A.shape[0], + A.shape[1], + B.shape[2], + dtype=dtype, + device=A.device) + def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, block_scale_a: torch.Tensor, @@ -263,6 +324,35 @@ def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, ) +def flashinfer_scaled_fp8_mm( + a: 
torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + assert a.ndim == 2 and b.ndim == 2 + assert a.shape[1] == b.shape[0] + assert scale_a.numel() == 1 and scale_b.numel() == 1 + assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn + assert a.device.type == "cuda" and b.device.type == "cuda" + assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32 + assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda" + + output = bmm_fp8( + a.unsqueeze(0), + b.unsqueeze(0), + scale_a, + scale_b, + out_dtype, + "auto", + ).view(a.shape[0], b.shape[1]) + + if bias is not None: + output = output + bias + return output + + __all__ = [ "has_flashinfer", "flashinfer_trtllm_fp8_block_scale_moe", @@ -274,6 +364,8 @@ __all__ = [ "has_flashinfer_moe", "has_flashinfer_cutlass_fused_moe", "has_nvidia_artifactory", + "supports_trtllm_attention", "use_trtllm_attention", "flashinfer_scaled_fp4_mm", + "flashinfer_scaled_fp8_mm", ] diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 9ed46331863c9..973979fdf7dfd 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -483,6 +483,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): attn_metadata: TorchSDPAMetadata, # type: ignore output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. 
@@ -497,7 +498,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): Returns: shape = [num_tokens, num_heads * head_size] """ - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for TorchSDPABackendImpl") diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ab7a71a399b34..6e7096de924ca 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -233,7 +233,7 @@ class FlashAttentionMetadataBuilder( num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu @@ -405,13 +405,6 @@ class FlashAttentionImpl(AttentionImpl): FlashAttentionBackend.validate_head_size(head_size) - if attn_type not in [ - AttentionType.DECODER, AttentionType.ENCODER_ONLY - ]: - raise NotImplementedError("Encoder/decoder cross-attention " - "is not implemented for " - "FlashAttentionImpl") - self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() if is_quantized_kv_cache(self.kv_cache_dtype) \ @@ -437,6 +430,7 @@ class FlashAttentionImpl(AttentionImpl): attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -454,7 +448,7 @@ class FlashAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." 
- if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") @@ -477,7 +471,7 @@ class FlashAttentionImpl(AttentionImpl): num_actual_tokens = attn_metadata.num_actual_tokens # Handle encoder attention differently - no KV cache needed - if attn_type in (AttentionType.ENCODER_ONLY, ): + if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): # For encoder attention, # we use direct Q, K, V tensors without caching return self._forward_encoder_attention(query[:num_actual_tokens], @@ -489,7 +483,11 @@ class FlashAttentionImpl(AttentionImpl): # For decoder and cross-attention, use KV cache as before key_cache, value_cache = kv_cache.unbind(0) - if self.kv_sharing_target_layer_name is None: + # key and value may be None in the case of cross attention. They are + # calculated once based on the output from the encoder and then cached + # in KV cache. + if (self.kv_sharing_target_layer_name is None and key is not None + and value is not None): # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. 
# NOTE(woosuk): Here, key and value are padded while slot_mapping is @@ -528,7 +526,7 @@ class FlashAttentionImpl(AttentionImpl): block_table = attn_metadata.block_table scheduler_metadata = attn_metadata.scheduler_metadata - descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) flash_attn_varlen_func( q=query[:num_actual_tokens], diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 02decb171fc05..1115fc606b055 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -6,21 +6,27 @@ from __future__ import annotations from dataclasses import dataclass from typing import ClassVar, Optional, Union +import numpy as np import torch from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, MultiLevelCascadeAttentionWrapper) -from flashinfer.decode import (_get_range_buf, get_seq_lens, - trtllm_batch_decode_with_kv_cache) +from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache from flashinfer.prefill import trtllm_batch_context_with_kv_cache +from flashinfer.utils import FP4Tensor -import vllm.envs as envs +from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8StaticTensorSym, kNvfp4Quant) +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton from vllm.utils import cdiv, is_pin_memory_available -from vllm.utils.flashinfer import use_trtllm_attention +from vllm.utils.flashinfer import (supports_trtllm_attention, + use_trtllm_attention) from vllm.v1.attention.backends.flash_attn import use_cascade_attention # yapf conflicts with isort for this block # yapf: disable @@ 
-31,10 +37,14 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport, get_per_layer_parameters, infer_global_hyperparameters, split_decodes_and_prefills) +# yapf: enable from vllm.v1.kv_cache_interface import AttentionSpec FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 + logger = init_logger(__name__) @@ -115,35 +125,6 @@ class FlashInferMetadata: num_actual_tokens: int # Number of tokens excluding padding. - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - qo_indptr_cpu: torch.Tensor - # An example for paged_kv_indices, paged_kv_indptr: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan) - paged_kv_indptr_cpu: torch.Tensor - # The page indices of the paged kv cache (on device for plan) - paged_kv_indices: torch.Tensor - # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] (CPU for plan) - paged_kv_last_page_len_cpu: torch.Tensor - # The number of query/output heads - num_qo_heads: int - # The number of key/value heads - num_kv_heads: int - # The dimension of the attention heads - head_dim: int - # Block size of vllm - page_size: int - # The data type of the paged kv cache - kv_data_type: torch.dtype # The data type of the query q_data_type: torch.dtype @@ -165,10 +146,6 @@ class FlashInferMetadata: # For cascade attention (CPU for planning). 
use_cascade: bool - shared_qo_indptr_cpu: Optional[torch.Tensor] = None - shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None - shared_kv_page_indices_cpu: Optional[torch.Tensor] = None - shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None @@ -177,10 +154,6 @@ class FlashInferMetadata: qo_indptr_gpu: Optional[torch.Tensor] = None paged_kv_indptr_gpu: Optional[torch.Tensor] = None - def __post_init__(self): - if self.head_dim is not None: - FlashInferBackend.validate_head_size(self.head_dim) - class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = \ @@ -193,13 +166,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.device = device self.vllm_config = vllm_config self.cache_config = vllm_config.cache_config + self.model_config = vllm_config.model_config self.kv_cache_spec = kv_cache_spec self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode (general shape) self.compilation_config = vllm_config.compilation_config - max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len, + max_num_pages_per_req = cdiv(self.model_config.max_model_len, self.kv_cache_spec.block_size) max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_pages = max_num_reqs * max_num_pages_per_req @@ -213,12 +187,37 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self._decode_cudagraph_max_bs = min( max_num_reqs, self.compilation_config.max_capture_size) + self.num_qo_heads = self.model_config.get_num_attention_heads( + self.vllm_config.parallel_config) + self.num_kv_heads = self.kv_cache_spec.num_kv_heads + self.head_dim = self.kv_cache_spec.head_size + FlashInferBackend.validate_head_size(self.head_dim) + 
self.page_size = self.kv_cache_spec.block_size + + self.enable_fusion = ( + self.compilation_config.pass_config.enable_attn_fusion) + self.q_data_type = self.model_config.dtype + self.cache_dtype = self.cache_config.cache_dtype + if self.cache_dtype.startswith("fp8"): + self.kv_cache_dtype = ( + FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.cache_dtype)) + # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled + if self.enable_fusion: + self.q_data_type = self.kv_cache_dtype + else: + self.kv_cache_dtype = self.kv_cache_spec.dtype + self._cascade_wrapper = None # Wrapper for cascade attention # Global hyperparameters shared by all attention layers # TODO: discard this for trtllm-gen backend self.global_hyperparameters = infer_global_hyperparameters( get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl)) + self.sm_scale = self.global_hyperparameters.sm_scale + self.window_left = self.global_hyperparameters.window_left + self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap + self.has_sinks = self.global_hyperparameters.has_sinks # Preparing persistent buffers (device-side) self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, @@ -237,6 +236,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): dtype=torch.int32, device="cpu", pin_memory=pin_memory) + self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy() self.paged_kv_indices_cpu = torch.zeros(max_num_pages, dtype=torch.int32, device="cpu", @@ -245,14 +245,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): dtype=torch.int32, device="cpu", pin_memory=pin_memory) - - self.block_table_arange = torch.arange(max_num_pages_per_req, - dtype=torch.int32, - device=self.device) + self.paged_kv_last_page_len_np = ( + self.paged_kv_last_page_len_cpu.numpy()) def _get_workspace_buffer(self): if self._workspace_buffer is None: - self._workspace_buffer = torch.empty( + self._workspace_buffer = torch.zeros( 
FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=self.device) @@ -274,14 +272,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): decode_wrapper = self._decode_wrapper if decode_wrapper is None: - num_qo_heads = ( - self.vllm_config.model_config.get_num_attention_heads( - self.vllm_config.parallel_config)) - num_kv_heads = self.vllm_config.model_config.get_num_kv_heads( - self.vllm_config.parallel_config) - use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( - num_qo_heads // num_kv_heads > 4) - if use_cudagraph: paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1] paged_kv_indices = self.paged_kv_indices @@ -298,7 +288,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): paged_kv_indptr_buffer=paged_kv_indptr, paged_kv_indices_buffer=paged_kv_indices, paged_kv_last_page_len_buffer=paged_kv_last_page_len, - use_tensor_cores=use_tensor_cores) + # Tensor cores are enabled by default because the perf would be + # atleast as good as cuda cores for all attention ops in latest + # gpus. 
+ use_tensor_cores=True, + ) # save the decode wrapper if use_cudagraph: @@ -314,133 +308,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): 2, self._get_workspace_buffer(), get_kv_cache_layout()) return self._cascade_wrapper - def _plan(self, attn_metadata: FlashInferMetadata): - if attn_metadata.use_cascade: - attn_metadata.cascade_wrapper = self._get_cascade_wrapper() - attn_metadata.cascade_wrapper.plan( - [ - attn_metadata.shared_qo_indptr_cpu, - attn_metadata.qo_indptr_cpu - ], - [ - attn_metadata.shared_kv_page_indptr_cpu, - attn_metadata.paged_kv_indptr_cpu - ], - [ - attn_metadata.shared_kv_page_indices_cpu, - attn_metadata.paged_kv_indices - ], - [ - attn_metadata.shared_kv_last_page_len_cpu, - attn_metadata.paged_kv_last_page_len_cpu - ], - attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, - attn_metadata.head_dim, - attn_metadata.page_size, - causal=True, - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters.logits_soft_cap, - q_data_type=attn_metadata.q_data_type, - kv_data_type=attn_metadata.kv_data_type, - ) - else: - # Regular attention (common case). 
- # Decodes are at the front and prefills are at the back, - # according to reorder_batch() - num_prefills = attn_metadata.num_prefills - num_decodes = attn_metadata.num_decodes - if num_prefills > 0: - # Decodes are first so prefills start after the last decode - prefill_start = num_decodes - attn_metadata.prefill_wrapper = self._get_prefill_wrapper() - assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[ - 0] == num_prefills + 1 - assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[ - 0] == num_prefills + 1 - assert attn_metadata.paged_kv_last_page_len_cpu[ - prefill_start:].shape[0] == num_prefills - # Since prefill_wrapper.run() will be called with - # query[num_decode_tokens:] we need to adjust the qo_indptr - # to be relative to the start of the prefill queries. - qo_indptr_cpu = attn_metadata.qo_indptr_cpu[ - prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start] - paged_kv_indptr_cpu = attn_metadata.paged_kv_indptr_cpu[ - prefill_start:] - if not attn_metadata.prefill_use_trtllm: - attn_metadata.prefill_wrapper.plan( - qo_indptr_cpu, - paged_kv_indptr_cpu, - attn_metadata.paged_kv_indices, - attn_metadata. - paged_kv_last_page_len_cpu[prefill_start:], - attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, - attn_metadata.head_dim, - attn_metadata.page_size, - causal=True, - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters. 
- logits_soft_cap, - q_data_type=attn_metadata.q_data_type, - kv_data_type=attn_metadata.kv_data_type, - ) - else: - attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device) - attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to( - self.device) - - if num_decodes > 0: - pure_decode = num_prefills == 0 - # possible required padding for cudagraph replay - use_cudagraph = (self.enable_cuda_graph and pure_decode and - num_decodes <= self._decode_cudagraph_max_bs) - if use_cudagraph: - num_input_tokens = ( - self.vllm_config.pad_for_cudagraph(num_decodes)) - # Carefully fulfill the padding region with reasonable value - # on cpu. - # Make sure paged_kv_indptr_cpu is not decreasing - self.paged_kv_indptr_cpu[1 + num_decodes:1 + - num_input_tokens].fill_( - attn_metadata. - paged_kv_indptr_cpu[-1]) - # Fill the remaining paged_kv_last_page_len_cpu with 1. - # This is because flashinfer treats 0 as a full page - # instead of empty. - self.paged_kv_last_page_len_cpu[ - num_decodes:num_input_tokens].fill_(1) - - else: - num_input_tokens = num_decodes - - attn_metadata.decode_wrapper = self._get_decode_wrapper( - num_input_tokens, use_cudagraph) - if not attn_metadata.decode_use_trtllm: - # Use the persistent buffer with padding length, - # instead of the same address but chunked version - # in atten_metadata when using cudagraph. - fast_plan_decode( - attn_metadata.decode_wrapper, - self.paged_kv_indptr_cpu[:num_input_tokens + 1], - attn_metadata.paged_kv_indices, - self.paged_kv_last_page_len_cpu[:num_input_tokens], - attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, - attn_metadata.head_dim, - attn_metadata.page_size, - # Disable flashinfer's pos encoding and use vllm's rope. - pos_encoding_mode="NONE", - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters. 
- logits_soft_cap, - q_data_type=attn_metadata.q_data_type, - kv_data_type=attn_metadata.kv_data_type, - ) - def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, @@ -450,14 +317,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ split_decodes_and_prefills(common_attn_metadata) - page_size = self.kv_cache_spec.block_size + page_size = self.page_size max_q_len = common_attn_metadata.max_query_len - max_seq_len = common_attn_metadata.seq_lens_cpu.max() + max_seq_len = common_attn_metadata.max_seq_len seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu + seq_lens_np = seq_lens_cpu.numpy() block_table_tensor = common_attn_metadata.block_table_tensor - block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size + num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size use_cascade = common_prefix_len > 0 if use_cascade: @@ -480,75 +348,63 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): # Remove the blocks of the shared prefix from all requests. 
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] - block_table_bounds_cpu -= num_common_kv_blocks + num_blocks_np -= num_common_kv_blocks else: shared_qo_indptr_cpu = None shared_kv_page_indptr_cpu = None shared_kv_page_indices_cpu = None shared_kv_last_page_len_cpu = None - max_num_blocks = block_table_bounds_cpu.max() - block_table_bounds = block_table_bounds_cpu.to(self.device, - non_blocking=True) - mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0) - < block_table_bounds.unsqueeze(1)) - # write self.paged_kv_indices inplace - num_actual_pages = torch.sum(mask) - paged_kv_indices = self.paged_kv_indices[:num_actual_pages] - torch.masked_select(block_table_tensor[:, :max_num_blocks], - mask, - out=paged_kv_indices) - # write self.paged_kv_indptr_cpu inplace (0-index is always 0) - torch.cumsum(block_table_bounds_cpu, - dim=0, - dtype=torch.int32, - out=self.paged_kv_indptr_cpu[1:1 + num_reqs]) + np.cumsum( + num_blocks_np, + dtype=np.int32, + out=self.paged_kv_indptr_np[1:num_reqs + 1], + ) + paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1] + paged_kv_indptr.copy_(self.paged_kv_indptr_cpu[:num_reqs + 1], + non_blocking=True) + + # write self.paged_kv_indices inplace + num_actual_pages = num_blocks_np.sum().item() + paged_kv_indices = self.paged_kv_indices[:num_actual_pages] + _copy_page_indices_kernel[(num_reqs, )]( + paged_kv_indices, + block_table_tensor, + block_table_tensor.stride(0), + paged_kv_indptr, + BLOCK_SIZE=1024, + ) - paged_kv_last_page_len_cpu = seq_lens_cpu % page_size # write self.paged_kv_last_page_len_cpu inplace - torch.where(paged_kv_last_page_len_cpu == 0, - torch.tensor(page_size), - paged_kv_last_page_len_cpu, - out=self.paged_kv_last_page_len_cpu[:num_reqs]) - - cache_dtype = self.cache_config.cache_dtype - if cache_dtype.startswith("fp8"): - kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - cache_dtype) - else: - kv_cache_dtype = self.kv_cache_spec.dtype - - num_qo_heads = 
self.vllm_config.model_config.get_num_attention_heads( - self.vllm_config.parallel_config) - num_kv_heads = self.kv_cache_spec.num_kv_heads - head_dim = self.kv_cache_spec.head_size + paged_kv_last_page_len_np = seq_lens_np % page_size + self.paged_kv_last_page_len_np[:num_reqs] = np.where( + paged_kv_last_page_len_np == 0, + page_size, + paged_kv_last_page_len_np, + ) # Check if any layer uses sinks (requires TRTLLM attention) - has_sinks = self.global_hyperparameters.has_sinks - - # currently prefill trtllm attention does not support fp8 kv cache - prefill_use_trtllm = not cache_dtype.startswith("fp8") \ - and use_trtllm_attention( - num_prefill_tokens, max_seq_len, cache_dtype, - num_qo_heads, num_kv_heads, head_dim, has_sinks) - decode_use_trtllm = use_trtllm_attention( - num_decode_tokens, max_seq_len, cache_dtype, - num_qo_heads, num_kv_heads, head_dim, has_sinks) + prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads, + self.num_kv_heads, + num_prefill_tokens, + max_seq_len, + self.cache_dtype, + self.q_data_type, + is_prefill=True, + has_sinks=self.has_sinks) + decode_use_trtllm = use_trtllm_attention(self.num_qo_heads, + self.num_kv_heads, + num_decode_tokens, + max_seq_len, + self.cache_dtype, + self.q_data_type, + is_prefill=False, + has_sinks=self.has_sinks) attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, - qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu, - paged_kv_indptr_cpu=self.paged_kv_indptr_cpu[:1 + num_reqs], - paged_kv_indices=paged_kv_indices, - paged_kv_last_page_len_cpu=self. 
- paged_kv_last_page_len_cpu[:num_reqs], - num_qo_heads=num_qo_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - page_size=page_size, - kv_data_type=kv_cache_dtype, - q_data_type=self.vllm_config.model_config.dtype, + q_data_type=self.q_data_type, slot_mapping=common_attn_metadata.slot_mapping, max_q_len=max_q_len, max_seq_len=max_seq_len, @@ -561,14 +417,121 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, use_cascade=use_cascade, - shared_qo_indptr_cpu=shared_qo_indptr_cpu, - shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu, - shared_kv_page_indices_cpu=shared_kv_page_indices_cpu, - shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu, ) - self._plan(attn_metadata) + qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu + paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_reqs] + paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs] + if attn_metadata.use_cascade: + attn_metadata.cascade_wrapper = self._get_cascade_wrapper() + attn_metadata.cascade_wrapper.plan( + [shared_qo_indptr_cpu, qo_indptr_cpu], + [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu], + [shared_kv_page_indices_cpu, paged_kv_indices], + [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + ) + else: + # Regular attention (common case). 
+ # Decodes are at the front and prefills are at the back, + # according to reorder_batch() + num_prefills = attn_metadata.num_prefills + num_decodes = attn_metadata.num_decodes + if num_prefills > 0: + # Decodes are first so prefills start after the last decode + prefill_start = num_decodes + attn_metadata.prefill_wrapper = self._get_prefill_wrapper() + assert qo_indptr_cpu[prefill_start:].shape[ + 0] == num_prefills + 1 + assert paged_kv_indptr_cpu[prefill_start:].shape[ + 0] == num_prefills + 1 + assert paged_kv_last_page_len_cpu[prefill_start:].shape[ + 0] == num_prefills + # Since prefill_wrapper.run() will be called with + # query[num_decode_tokens:] we need to adjust the qo_indptr + # to be relative to the start of the prefill queries. + qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[ + prefill_start] + paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] + if not attn_metadata.prefill_use_trtllm: + attn_metadata.prefill_wrapper.plan( + qo_indptr_cpu, + paged_kv_indptr_cpu, + paged_kv_indices, + paged_kv_last_page_len_cpu[prefill_start:], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + ) + else: + attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device) + attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to( + self.device) + + if num_decodes > 0: + pure_decode = num_prefills == 0 + # possible required padding for cudagraph replay + use_cudagraph = (self.enable_cuda_graph and pure_decode and + num_decodes <= self._decode_cudagraph_max_bs) + if use_cudagraph: + num_input_tokens = ( + self.vllm_config.pad_for_cudagraph(num_decodes)) + # Carefully fulfill the padding region with reasonable value + # on cpu. 
+ # Make sure paged_kv_indptr_cpu is not decreasing + self.paged_kv_indptr_cpu[1 + num_decodes:1 + + num_input_tokens].fill_( + paged_kv_indptr_cpu[-1]) + # Fill the remaining paged_kv_last_page_len_cpu with 1. + # This is because flashinfer treats 0 as a full page + # instead of empty. + self.paged_kv_last_page_len_cpu[ + num_decodes:num_input_tokens].fill_(1) + + else: + num_input_tokens = num_decodes + + attn_metadata.decode_wrapper = self._get_decode_wrapper( + num_input_tokens, use_cudagraph) + if not attn_metadata.decode_use_trtllm: + # Use the persistent buffer with padding length, + # instead of the same address but chunked version + # in atten_metadata when using cudagraph. + fast_plan_decode( + attn_metadata.decode_wrapper, + self.paged_kv_indptr_cpu[:num_input_tokens + 1], + paged_kv_indices, + self.paged_kv_last_page_len_cpu[:num_input_tokens], + seq_lens_cpu[:num_input_tokens], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + ) return attn_metadata def build_for_cudagraph_capture( @@ -622,6 +585,8 @@ class FlashInferImpl(AttentionImpl): self.sliding_window = (-1, -1) else: self.sliding_window = (sliding_window - 1, 0) + self.window_left = (self.sliding_window[0] + if self.sliding_window is not None else -1) self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap self.kv_sharing_target_layer_name = kv_sharing_target_layer_name @@ -640,13 +605,20 @@ class FlashInferImpl(AttentionImpl): raise ValueError( "Sinks must have the same number of heads as the number of " f"heads in the layer. Expected {num_heads}, but got " - f"{sinks.shape[0]}." 
- ) - # Cast sinks to float32 if needed (FlashInfer requirement) - if sinks.dtype != torch.float32: - sinks = sinks.to(torch.float32) + f"{sinks.shape[0]}.") self.sinks = sinks + self.support_trtllm_attn = (supports_trtllm_attention() + and num_heads % num_kv_heads == 0) + self.bmm1_scale: Optional[float] = None + self.bmm2_scale: Optional[float] = None + self.o_sf_scale: Optional[float] = None + + def fused_output_quant_supported(self, quant_key: QuantKey): + return (self.support_trtllm_attn + and self.kv_cache_dtype.startswith("fp8") + and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)) + def forward( self, layer: torch.nn.Module, @@ -657,6 +629,7 @@ class FlashInferImpl(AttentionImpl): attn_metadata: FlashInferMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashInfer. @@ -675,15 +648,56 @@ class FlashInferImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashInferImpl") - if attn_metadata is None: # Profiling run. return output + if self.bmm1_scale is None: + self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float * + self.scale) + + if self.bmm2_scale is None: + self.bmm2_scale = layer._v_scale_float + + # The attn+quant fusion happens when output_scale is provided. + if output_scale is None: + assert attn_metadata.q_data_type != FP8_DTYPE, \ + "Query can only be FP8 if output fusion happened." + assert output_block_scale is None, "output_block_scale "\ + "is not supported when fusion has not happened" + else: + assert attn_metadata.q_data_type == FP8_DTYPE, \ + "Query must be FP8 when attn+quant fusion happened." 
+ assert (attn_metadata.prefill_use_trtllm and + attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn" + + if output.dtype == FP8_DTYPE: + assert output_block_scale is None, \ + "output_block_scale should not be provided for fp8 output" + elif output.dtype == FP4_DTYPE: + assert output_block_scale is not None, \ + "output_block_scale is required for nvfp4 output" + else: + raise ValueError(f"Unsupported output dtype: {output.dtype}") + + # TRTLLM attn kernel requires o scale to pass as a host scalar, + # store the o scale as a host scalar in warmup run with cuda graph + # not enabled + if layer._o_scale_float is None: + layer._o_scale_float = output_scale.cpu().item() + if output.dtype == FP8_DTYPE: + self.bmm2_scale = self.bmm2_scale / layer._o_scale_float + elif output.dtype == FP4_DTYPE: + self.o_sf_scale = layer._o_scale_float + + # Insert FP8 quant for query + num_tokens, num_heads, head_size = query.shape + query, _ = ops.scaled_fp8_quant( + query.reshape( + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) + # IMPORTANT! # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in # eager-mode PyTorch. 
Thus, we need to be careful about any CPU overhead @@ -721,9 +735,6 @@ class FlashInferImpl(AttentionImpl): self.kv_cache_dtype) kv_cache = kv_cache.view(torch_dtype) - window_left = (self.sliding_window[0] - if self.sliding_window is not None else -1) - # Inputs and outputs may be padded for CUDA graphs query = query[:num_actual_tokens] output_padded = output @@ -751,7 +762,7 @@ class FlashInferImpl(AttentionImpl): if not attn_metadata.prefill_use_trtllm: assert prefill_wrapper._causal - assert prefill_wrapper._window_left == window_left + assert prefill_wrapper._window_left == self.window_left assert prefill_wrapper._logits_soft_cap == ( self.logits_soft_cap or 0.0) assert prefill_wrapper._sm_scale == self.scale @@ -778,6 +789,16 @@ class FlashInferImpl(AttentionImpl): assert block_tables_prefill.is_contiguous() assert seq_lens_prefill.is_contiguous() + if output.dtype == FP4_DTYPE: + assert self.o_sf_scale is not None + out = FP4Tensor(data=output[num_decode_tokens:], + scale=output_block_scale, + scale_start_index=num_decode_tokens, + original_shape=prefill_query.shape) + else: + assert self.o_sf_scale is None + out = output[num_decode_tokens:] + trtllm_batch_context_with_kv_cache( query=prefill_query, kv_cache=kv_cache_permute, @@ -786,14 +807,15 @@ class FlashInferImpl(AttentionImpl): seq_lens=seq_lens_prefill, max_q_len=attn_metadata.max_q_len, max_kv_len=attn_metadata.max_seq_len, - bmm1_scale=layer._k_scale_float * self.scale, - bmm2_scale=layer._v_scale_float, + bmm1_scale=self.bmm1_scale, + bmm2_scale=self.bmm2_scale, batch_size=attn_metadata.num_prefills, cum_seq_lens_q=attn_metadata.qo_indptr_gpu, cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu, - window_left=window_left, + window_left=self.window_left, sinks=self.sinks, - out=output[num_decode_tokens:], + o_sf_scale=self.o_sf_scale, + out=out, ) if num_decode_tokens > 0: @@ -803,7 +825,7 @@ class FlashInferImpl(AttentionImpl): assert decode_wrapper is not None if not attn_metadata.decode_use_trtllm: 
- assert decode_wrapper._window_left == window_left + assert decode_wrapper._window_left == self.window_left assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert decode_wrapper._sm_scale == self.scale @@ -818,8 +840,8 @@ class FlashInferImpl(AttentionImpl): # decode_query may be non-contiguous decode_query = decode_query.contiguous() workspace_buffer = decode_wrapper._float_workspace_buffer - block_tables_decode = attn_metadata.block_table_tensor[: - num_decode_tokens] + block_tables_decode = attn_metadata.\ + block_table_tensor[:num_decode_tokens] seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens] # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND @@ -830,6 +852,16 @@ class FlashInferImpl(AttentionImpl): assert block_tables_decode.is_contiguous() assert seq_lens_decode.is_contiguous() + if output.dtype == FP4_DTYPE: + assert self.o_sf_scale is not None + out = FP4Tensor(data=output[:num_decode_tokens], + scale=output_block_scale, + scale_start_index=0, + original_shape=decode_query.shape) + else: + assert self.o_sf_scale is None + out = output[:num_decode_tokens] + trtllm_batch_decode_with_kv_cache( query=decode_query, kv_cache=kv_cache_permute, @@ -837,11 +869,12 @@ class FlashInferImpl(AttentionImpl): block_tables=block_tables_decode, seq_lens=seq_lens_decode, max_seq_len=attn_metadata.max_seq_len, - bmm1_scale=layer._k_scale_float * self.scale, - bmm2_scale=layer._v_scale_float, - window_left=window_left, + bmm1_scale=self.bmm1_scale, + bmm2_scale=self.bmm2_scale, + window_left=self.window_left, sinks=self.sinks, - out=output[:num_decode_tokens], + o_sf_scale=self.o_sf_scale, + out=out, ) return output_padded @@ -851,6 +884,7 @@ def fast_plan_decode( indptr_cpu: torch.Tensor, indices: torch.Tensor, last_page_len_cpu: torch.Tensor, + seq_lens_cpu: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, @@ -928,9 +962,6 @@ def fast_plan_decode( kv_data_type = getattr(torch, kv_data_type) if isinstance( 
kv_data_type, str) else kv_data_type - if self.use_tensor_cores: - qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") - if batch_size != self._fixed_batch_size: raise ValueError( "The batch size should be fixed in cudagraph mode, the runtime " @@ -947,56 +978,29 @@ def fast_plan_decode( self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True) - indptr_host = indptr_cpu - last_page_len_host = last_page_len_cpu + qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") - if self.use_tensor_cores: - kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, - page_size) - - try: - # Make sure we pass exactly 15 arguments for tensor core version - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_host, - indptr_host, - kv_lens_arr_host, - batch_size, # total_num_rows - batch_size, - num_qo_heads, - num_kv_heads, - page_size, - self.is_cuda_graph_enabled, - head_dim, - head_dim, - False, # causal - ) - except Exception as e: - raise RuntimeError(f"Error in tensor core plan: {e}") from e - else: - try: - # Make sure we pass exactly 15 arguments for standard version - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - indptr_host, - batch_size, - num_qo_heads, - num_kv_heads, - page_size, - self.is_cuda_graph_enabled, - window_left, - logits_soft_cap, - head_dim, - head_dim, - torch.empty(0, dtype=q_data_type), - torch.empty(0, dtype=kv_data_type), - ) - except Exception as e: - raise RuntimeError(f"Error in standard plan: {e}") from e + try: + # Make sure we pass exactly 15 arguments for tensor core version + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + indptr_cpu, + seq_lens_cpu, + batch_size, # total_num_rows + 
batch_size, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + head_dim, + head_dim, + False, # causal + ) + except Exception as e: + raise RuntimeError(f"Error in tensor core plan: {e}") from e self._pos_encoding_mode = pos_encoding_mode self._window_left = window_left @@ -1004,3 +1008,25 @@ def fast_plan_decode( self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta + + +@triton.jit +def _copy_page_indices_kernel( + page_indices, + block_table, + block_table_stride, + cu_num_blocks, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + row_ptr = block_table + req_idx * block_table_stride + start_idx = tl.load(cu_num_blocks + req_idx) + end_idx = tl.load(cu_num_blocks + req_idx + 1) + num_blocks = end_idx - start_idx + + offset = tl.arange(0, BLOCK_SIZE) + for i in tl.range(0, num_blocks, BLOCK_SIZE): + block_ids = tl.load(row_ptr + i + offset, mask=i + offset < num_blocks) + tl.store(page_indices + start_idx + i + offset, + block_ids, + mask=i + offset < num_blocks) diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index e599411b2d7e8..458562ebc8d27 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with FlashAttention.""" -from collections import defaultdict +"""Attention layer with FlexAttention.""" + from dataclasses import dataclass -from typing import Optional +from typing import TYPE_CHECKING, Optional, Union import torch +import torch._dynamo.decorators +import torch.nn.functional as F from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, _score_mod_signature, create_block_mask, @@ -16,13 +18,17 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, is_quantized_kv_cache) from 
vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.platforms import current_platform +from vllm.utils import cdiv, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + create_block_mask_compiled = torch.compile(create_block_mask, fullgraph=True, mode="reduce-overhead") @@ -36,6 +42,23 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor: torch.arange(len(counts), device=device, dtype=torch.int32), counts) +def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): + difference = (multiple - (x.shape[dim] % multiple)) % multiple + if difference == 0: + return x + + dim = dim if dim >= 0 else x.ndim + dim + pad_list = [] + + for i in range(x.ndim - 1, dim - 1, -1): + if i == dim: + pad_list.extend([0, difference]) + else: + pad_list.extend([0, 0]) + + return F.pad(x, pad_list, mode="constant", value=0) + + class FlexAttentionBackend(AttentionBackend): accept_output_buffer: bool = True @@ -77,10 +100,10 @@ class FlexAttentionBackend(AttentionBackend): return False -# @torch.compile(fullgraph=True, mode="reduce-overhead") -def physical_to_logical_mapping( - block_table: torch.Tensor, - total_blocks: Optional[int] = None) -> torch.Tensor: +#@torch.compile(fullgraph=True, mode="reduce-overhead") +def physical_to_logical_mapping(block_table: torch.Tensor, + seq_lens: torch.Tensor, block_size: int, + total_blocks: int) -> torch.Tensor: """ Creates an inverse mapping from physical block locations to logical indices. @@ -114,13 +137,38 @@ def physical_to_logical_mapping( If a physical block is not mapped to by any logical block, its value in the result will be -1. 
+ IMPORTANT: Garbage Value Protection + ──────────────────────────────────── + The block_table tensor may contain garbage values in unused positions + (beyond the actual sequence length). For example, if a sequence only + needs 3 blocks but the table has space for 8: + + block_table[0] = [10, 25, 7, 999, 1234, 888, ...] + ^^^^^^^^^^^^^^^^^^^^ + garbage values + + These garbage values can cause issues because: + 1. They may map to valid physical blocks by coincidence + 2. The scatter_ operation will assign them logical indices + 3. Later attention computations may incorrectly access these blocks + + To prevent this, we use seq_lens and block_size to mask out unused + entries, ensuring only valid block references are processed. Args: block_table: Tensor of shape [max_reqs, max_num_blocks] - mapping logical blocks to physical locations + mapping logical blocks to physical locations. May contain + garbage values in unused positions. + seq_lens: Tensor of sequence lengths for each request. Used to + determine how many blocks are actually needed per sequence. + block_size: Size of each block in tokens. Used with seq_lens to + compute the number of valid blocks per sequence. + total_blocks: Total number of physical blocks available Returns: - A tensor of shape [max_reqs, max_physical_block] + A tensor of shape [max_reqs, total_blocks] where each entry + physical_to_logical[req_id, physical_block] contains the logical + block index for that physical block, or -1 if unused. 
""" max_reqs, max_num_blocks = block_table.shape device = block_table.device @@ -130,17 +178,76 @@ def physical_to_logical_mapping( dtype=torch.long, device=device) - logical_indices = (torch.arange(max_num_blocks, - device=device).unsqueeze(0).expand( - max_reqs, -1)) + # Only process valid blocks to avoid garbage values + num_blocks_per_seq = cdiv(seq_lens, block_size) + mask = torch.arange(max_num_blocks, + device=device)[None, :] < num_blocks_per_seq[:, None] - physical_to_logical.scatter_(-1, block_table.to(torch.int64), - logical_indices) - # TODO Confirm - Seems like block 0 is always empty so we reset it manually + valid_block_table = torch.where(mask, block_table, 0) + valid_logical_indices = torch.where( + mask, + torch.arange(max_num_blocks, device=device)[None, :], 0) + + physical_to_logical.scatter_(-1, valid_block_table.to(torch.int64), + valid_logical_indices) + # NB - Seems like block 0 is always empty so we reset it manually physical_to_logical[:, 0] = -1 return physical_to_logical +def unique_static_unsorted( + x: torch.Tensor, + *, + M: int, # maximum positive value (0 is “skip me”) + dim: int = -1, # axis along which to deduplicate + ignored_val: int = 0, # value to ignore + pad_val: int = -1, # sentinel for unused slots +) -> torch.Tensor: + """ + - Keeps the first occurrence of each non-zero value while preserving order, + then left-packs those uniques and fills the rest with `pad_val`. + - Returns (packed, keep_mask) with the *same shape* as `x`. + - Requires that all values be in the range [0, M] + - Skips ignored_val + + Works on CPU or GPU, no Python loops, O(B·N) time / O(B·M) memory. 
+ + Example: + x =[3, 1, 0, 1, 2], M=3, ignored_val=0 => [3, 1, 2, -1, -1] + """ + if not (-1 <= pad_val <= M): + raise ValueError("`pad_val` must lie in [-1, M]") + + # ── move `dim` to the end so we can treat tensor as [B, N] ────────── + dim = dim % x.ndim + x_perm = x.movedim(dim, -1) # shape [..., N] + B, N = x_perm.numel() // x_perm.shape[-1], x_perm.shape[-1] + x_flat = x_perm.reshape(B, N) # [B, N] + + device = x.device + idx = torch.arange(N, device=device).expand(B, N) # per-row indices + + # ── build first-occurrence table for every v ∈ [0, M] ─────────────── + first_idx = torch.full((B, M + 1), N, device=device) # “∞” + # scatter_reduce_: first_idx[b, v] = min(first_idx[b, v], i) for each i + first_idx.scatter_reduce_(1, x_flat, idx, reduce="amin") + + # ── keep mask: first occurrence *and* value ≠ 0 ───────────────────── + keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, x_flat) + ) # [B, N] + + # ── left-pack uniques into a fresh tensor ─────────────────────────── + dest_pos = torch.cumsum(keep.to(torch.long), dim=1) - 1 # where to go + packed_flat = torch.full_like(x_flat, pad_val) + + rows, src_cols = torch.nonzero(keep, as_tuple=True) + packed_flat[rows, dest_pos[rows, src_cols]] = x_flat[rows, src_cols] + + # ── restore original layout ───────────────────────────────────────── + packed = packed_flat.reshape(x_perm.shape).movedim(-1, dim) + return packed + + def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor): return q_idx >= kv_idx @@ -170,6 +277,7 @@ class FlexAttentionMetadata: num_reqs: int physical_to_logical: torch.Tensor decode_offset: torch.Tensor + num_blocks_per_seq: torch.Tensor # For logging. num_input_tokens: int = 0 # Number of tokens including padding. 
@@ -179,6 +287,46 @@ class FlexAttentionMetadata: block_mask: Optional[BlockMask] = None score_mod: Optional[_score_mod_signature] = None logical_mask_mod: _mask_mod_signature = causal_mask_mod + doc_ids: Optional[torch.Tensor] = None + direct_build: bool = True + q_block_size: int = 16 + kv_block_size: int = 16 + transformed_score_mod: Optional[_score_mod_signature] = None + + def _convert_physical_to_logical( + self, + request_lookup: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert physical indices to logical indices for both query and kv. + + NB is_within_lower_bound: do sequences start on block_boundaries? + + Returns: + tuple of (is_valid, logical_q_idx, logical_kv_idx) + """ + # Map query indices to corresponding request indices + q_req = request_lookup[q_idx] + + # Convert physical KV indices to logical indices + physical_kv_block = physical_kv_idx // self.block_size + physical_kv_offset = physical_kv_idx % self.block_size + logical_block_idx = self.physical_to_logical[q_req, physical_kv_block] + logical_kv_idx = (logical_block_idx * self.block_size + + physical_kv_offset) + + # Determine valid kv indices + live_block = logical_block_idx >= 0 + within_upper_bound = logical_kv_idx < self.seq_lens[q_req] + within_lower_bound = logical_kv_idx >= 0 + is_valid = live_block & within_upper_bound & within_lower_bound + + # Convert physical query indices to logical indices + local_q_idx = q_idx - self.query_start_loc[q_req] + logical_q_idx = local_q_idx + self.decode_offset[q_req] + + return is_valid, logical_q_idx, logical_kv_idx def get_causal_mask_mod(self) -> _mask_mod_signature: """Creates the mask_mod function for FlexAttention. @@ -191,11 +339,8 @@ class FlexAttentionMetadata: With this info we create the "logical" indices that are passed to mask_mod functions. This allows mask mod functions to be agnostic to layout of the query and key/value tensors. 
- - TODO is_within_lower_bound: do sequences start on block_boundaries? """ - # Create a lookup mapping from query indices -> request number - request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) + assert self.doc_ids is not None def final_mask_mod( b: torch.Tensor, @@ -203,27 +348,9 @@ class FlexAttentionMetadata: q_idx: torch.Tensor, physical_kv_idx: torch.Tensor, ) -> torch.Tensor: - # Map query indices to corresponding request indices - q_req = request_lookup[q_idx] - - # Convert physical KV indices to logical indices - physical_kv_block = physical_kv_idx // self.block_size - physical_kv_offset = physical_kv_idx % self.block_size - logical_block_idx = self.physical_to_logical[q_req, - physical_kv_block] - logical_kv_idx = logical_block_idx * self.block_size + physical_kv_offset # noqa: E501 - - # Determine valid kv indices - live_block = logical_block_idx >= 0 - within_upper_bound = logical_kv_idx < self.seq_lens[q_req] - within_lower_bound = logical_kv_idx >= 0 - - is_valid = live_block & within_upper_bound & within_lower_bound - - # Convert physical query indices to logical indices - local_q_idx = q_idx - self.query_start_loc[q_req] - logical_q_idx = local_q_idx + self.decode_offset[q_req] - + (is_valid, logical_q_idx, + logical_kv_idx) = self._convert_physical_to_logical( + self.doc_ids, q_idx, physical_kv_idx) # Apply mask modification only for valid indices return torch.where( is_valid, @@ -236,7 +363,7 @@ class FlexAttentionMetadata: def get_bidirectional_mask_mod(self) -> _mask_mod_signature: """Creates the encoder mask_mod function for FlexAttention. - Since the encoder bidirectional attention doesn't run with + Since the encoder bidirectional attention doesn't run with KV cache, this function creates a mask based on the packed query sequences. 
""" @@ -253,6 +380,97 @@ class FlexAttentionMetadata: return final_mask_mod + def get_transformed_score_mod(self) -> Optional[_score_mod_signature]: + """Creates the transformed score_mod function for FlexAttention. + + This function wraps the user's score_mod to handle physical-to-logical + index conversion, similar to how get_mask_mod works for mask functions. + """ + if self.score_mod is None: + return None + + # Create a lookup mapping from query indices -> request number + request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) + user_score_mod = self.score_mod + + def transformed_score_mod( + score: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> torch.Tensor: + (is_valid, logical_q_idx, + logical_kv_idx) = self._convert_physical_to_logical( + request_lookup, q_idx, physical_kv_idx) + + return torch.where( + is_valid, + user_score_mod(score, + b, + h, + logical_q_idx, + logical_kv_idx, + physical_q=q_idx), -float('inf')) + + return transformed_score_mod + + def _build_block_mask_direct(self) -> BlockMask: + """Direct block mask construction for standard causal attention. + + This method constructs the block mask directly using + BlockMask.from_kv_blocks which is much more efficient than the + generic create_block_mask approach. + + The direct path works as follows: + 1. For each query token, fetch blocks from block_table using max_seq_len + (this fetches more blocks than needed for shorter sequences) + 2. Group query tokens into chunks of q_block_size + 3. For each group, deduplicate the blocks using unique_static_unsorted + 4. Create BlockMask using the deduplicated block indices + + Over-estimation occurs when a group of q_block_size tokens contains + multiple sequence IDs (doc_ids). 
In this case, we fetch ALL blocks for + each sequence represented in the group, even though individual query + tokens may only need a subset of those blocks based on causal masking + and their position. + + """ + page_to_block_ratio = self.kv_block_size // self.block_size + if page_to_block_ratio != 1: + raise ValueError( + f"FlexAttention currently requires the cache block size " + f"({self.block_size}) to be equal to the kv_block_size " + f"({self.kv_block_size}). Please check your model's " + f"configuration.") + + used_pages = self.block_table[ + self.doc_ids, :cdiv(self.max_seq_len, self.block_size)] + used_pages_padded = pad_to_multiple(used_pages, + multiple=self.q_block_size, + dim=0) + used_pages_padded = used_pages_padded.reshape( + used_pages_padded.shape[0] // self.q_block_size, -1) + used_pages_padded = used_pages_padded // page_to_block_ratio + kv_indices = unique_static_unsorted((used_pages_padded.long()), + M=self.num_blocks).to(torch.int32) + + kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32) + block_mask_kwargs = { + "seq_lengths": (self.num_actual_tokens, self.total_cache_tokens), + "kv_num_blocks": kv_num_blocks[None, None], + "kv_indices": kv_indices[None, None], + "full_kv_num_blocks": None, + "full_kv_indices": None, + "BLOCK_SIZE": (self.q_block_size, self.kv_block_size), + "mask_mod": self.mask_mod, + } + + # compute_q_blocks parameter is available in PyTorch 2.9+ + if is_torch_equal_or_newer("2.9.0.dev0"): + block_mask_kwargs["compute_q_blocks"] = False + return BlockMask.from_kv_blocks(**block_mask_kwargs) + def build_block_mask(self) -> BlockMask: if self.causal: mask_mod = self.get_causal_mask_mod() @@ -267,6 +485,7 @@ class FlexAttentionMetadata: self.num_actual_tokens, kv_len, device=self.block_table.device, + BLOCK_SIZE=(self.q_block_size, self.kv_block_size), ) def __post_init__(self): @@ -275,8 +494,21 @@ class FlexAttentionMetadata: assert self.cu_prefix_query_lens is None, "Not implemented yet." 
assert self.prefix_kv_lens is None, "Not implemented yet." assert self.suffix_kv_lens is None, "Not implemented yet." + # Create a lookup mapping from query indices -> request number + self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc) self.num_blocks = self.total_cache_tokens // self.block_size - self.block_mask = self.build_block_mask() + + if self.causal: + self.mask_mod = self.get_causal_mask_mod() + else: + self.mask_mod = self.get_bidirectional_mask_mod() + + self.transformed_score_mod = self.get_transformed_score_mod() + + if self.direct_build and self.causal: + self.block_mask = self._build_block_mask_direct() + else: + self.block_mask = self.build_block_mask() class FlexAttentionMetadataBuilder( @@ -287,15 +519,24 @@ class FlexAttentionMetadataBuilder( self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config + self.device = device self.num_heads_q = self.model_config.get_num_attention_heads( - vllm_config.parallel_config) + self.parallel_config) self.num_heads_kv = self.model_config.get_num_kv_heads( - vllm_config.parallel_config) + self.parallel_config) self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec - self.device = device + self.direct_build: bool = is_torch_equal_or_newer("2.9.0.dev0") + self.q_block_size: int = 16 if is_torch_equal_or_newer( + "2.9.0.dev0") else 128 + self.kv_block_size: int = 16 if is_torch_equal_or_newer( + "2.9.0.dev0") else 128 + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + return False def build(self, common_prefix_len: int, @@ -305,11 +546,12 @@ class FlexAttentionMetadataBuilder( num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len 
query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping + num_blocks_per_seq = cdiv(seq_lens, self.block_size) use_cascade = common_prefix_len > 0 cu_prefix_query_lens = None @@ -320,12 +562,15 @@ class FlexAttentionMetadataBuilder( block_size = self.kv_cache_spec.block_size max_possible_seq_len = self.model_config.max_model_len - total_cache_tokens = self.cache_config.num_gpu_blocks * block_size + num_gpu_blocks = self.cache_config.num_gpu_blocks + + assert num_gpu_blocks is not None, \ + "FlexAttention requires num_gpu_blocks to be set" + total_cache_tokens = (num_gpu_blocks * block_size) inverse_block_table = physical_to_logical_mapping( - block_table_tensor, self.cache_config.num_gpu_blocks) + block_table_tensor, seq_lens, block_size, num_gpu_blocks) - # Get the original offset tensor offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to( self.device, non_blocking=True) @@ -349,9 +594,16 @@ class FlexAttentionMetadataBuilder( physical_to_logical=inverse_block_table, total_cache_tokens=total_cache_tokens, decode_offset=offset_tensor, + num_blocks_per_seq=num_blocks_per_seq, + direct_build=self.direct_build, + q_block_size=self.q_block_size, + kv_block_size=self.kv_block_size, ) return out + def use_cascade_attention(self, *args, **kwargs) -> bool: + return False + class FlexAttentionImpl(AttentionImpl): sliding_window: Optional[tuple[int, int]] @@ -370,6 +622,7 @@ class FlexAttentionImpl(AttentionImpl): logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, + **kwargs, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -398,6 +651,7 @@ class FlexAttentionImpl(AttentionImpl): raise NotImplementedError( "FlexAttention does not support logits soft cap yet.") + assert self.num_heads % self.num_kv_heads == 
0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads if kv_sharing_target_layer_name is not None: @@ -405,7 +659,6 @@ class FlexAttentionImpl(AttentionImpl): "FlexAttention does not support kv sharing yet.") FlexAttentionBackend.validate_head_size(head_size) - if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( "FlexAttention does not support quantized kv-cache. Yet") @@ -428,6 +681,7 @@ class FlexAttentionImpl(AttentionImpl): attn_metadata: FlexAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FLexAttention. @@ -441,7 +695,7 @@ class FlexAttentionImpl(AttentionImpl): shape = [num_tokens, num_heads * head_size] """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlexAttentionImpl") @@ -492,35 +746,48 @@ class FlexAttentionImpl(AttentionImpl): # Doesn't work for now -> constraint violation # torch._dynamo.try_mark_dynamic(query, 2) - # default M=64, N=64 may run out of shared memory on some GPUs - # TODO: Explicit configs for each GPU? 
- # Not sure how to calculate the shared memory requirement - extra_kernel_options = defaultdict[str, int](lambda: 64) - if query.dtype == torch.float32: - extra_kernel_options["BLOCK_M"] //= 2 - extra_kernel_options["BLOCK_N"] //= 2 - if current_platform.is_cuda(): - device_props = torch.cuda.get_device_properties() - max_shared_memory = device_props.shared_memory_per_block_optin - if max_shared_memory < 144 * 1024: - extra_kernel_options["BLOCK_M"] //= 2 - extra_kernel_options["BLOCK_N"] //= 2 + assert attn_metadata.block_mask is not None + block_m, block_n = attn_metadata.block_mask.BLOCK_SIZE + kernel_options = get_kernel_options(query, block_m, block_n, + attn_metadata.direct_build) out = flex_attention_compiled( query, key_tensor, value_tensor, - attn_metadata.score_mod, + attn_metadata.transformed_score_mod, attn_metadata.block_mask, self.scale, enable_gqa=enable_gqa, - kernel_options={ - "FORCE_USE_FLEX_ATTENTION": True, - **extra_kernel_options - }, + kernel_options=kernel_options, ) # Flex doesn't have an out variant today, rely on epilogue fusion out = out.permute(0, 2, 1, 3).squeeze(0) output[:num_actual_tokens, :, :].copy_(out) return output + + +def get_kernel_options(query, block_m, block_n, + use_direct_build: bool) -> dict[str, Union[int, bool]]: + kernel_options: dict[str, Union[int, bool]] = { + "FORCE_USE_FLEX_ATTENTION": True, + } + if use_direct_build: + kernel_options["BLOCK_M"] = block_m + kernel_options["BLOCK_N"] = block_n + return kernel_options + else: + kernel_options["BLOCK_M"] = 64 + kernel_options["BLOCK_N"] = 64 + if query.dtype == torch.float32: + kernel_options["BLOCK_M"] = 32 + kernel_options["BLOCK_N"] = 32 + # if current_platform.is_cuda(): + if torch.cuda.is_available(): + device_props = torch.cuda.get_device_properties() + max_shared_memory = device_props.shared_memory_per_block_optin + if max_shared_memory < 144 * 1024: + kernel_options["BLOCK_M"] = 32 + kernel_options["BLOCK_N"] = 32 + return kernel_options diff --git 
a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 6cdc509083ae9..97a1aa86dda0d 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -2,16 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import Optional import torch from vllm.attention.backends.abstract import AttentionBackend -from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.v1.attention.backends.mamba_attn import ( + BaseMambaAttentionMetadataBuilder) +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec class Mamba1AttentionBackend(AttentionBackend): @@ -31,24 +31,11 @@ class Mamba1AttentionMetadata: num_prefill_tokens: int num_decodes: int num_decode_tokens: int + num_padded_decodes: int class Mamba1AttentionMetadataBuilder( - AttentionMetadataBuilder[Mamba1AttentionMetadata]): - reorder_batch_threshold: ClassVar[int] = 1 - - def __init__( - self, - kv_cache_spec: AttentionSpec, - vllm_config: VllmConfig, - device: torch.device, - layer_names: list[str], - ): - assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec - self.device = device - self.vllm_config = vllm_config - self.layer_names = layer_names + BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]): def build( self, @@ -67,9 +54,18 @@ class Mamba1AttentionMetadataBuilder( decode_threshold=1)) has_initial_states = None + padded_decodes = num_decodes if num_prefills > 0: has_initial_states = context_lens_tensor > 0 + elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs + and self.compilation_config.full_cuda_graph): + 
state_indices_for_decode = state_indices_tensor[:num_decodes] + padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes) + self.state_indices_tensor[:num_decodes].copy_( + state_indices_for_decode, non_blocking=True) + state_indices_tensor = self.state_indices_tensor[:padded_decodes] + state_indices_tensor[num_decodes:] = PAD_SLOT_ID return Mamba1AttentionMetadata( query_start_loc=query_start_loc, @@ -80,4 +76,5 @@ class Mamba1AttentionMetadataBuilder( num_prefill_tokens=num_prefill_tokens, num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, + num_padded_decodes=padded_decodes, ) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index ace078e2b27c6..ed30884fdbc94 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -2,18 +2,18 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import Optional import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, +from vllm.v1.attention.backends.mamba_attn import ( + BaseMambaAttentionMetadataBuilder) +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec +from vllm.v1.kv_cache_interface import AttentionSpec def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, @@ -88,29 +88,14 @@ class Mamba2AttentionMetadata: class Mamba2AttentionMetadataBuilder( - AttentionMetadataBuilder[Mamba2AttentionMetadata]): - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE - - reorder_batch_threshold: 
ClassVar[int] = 1 + BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]): def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec + super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() - self.vllm_config = vllm_config - self.compilation_config = vllm_config.compilation_config assert self.chunk_size is not None, ( "chunk_size needs to be set in the model config for Mamba2 models") - self.decode_cudagraph_max_bs = min( - self.vllm_config.scheduler_config.max_num_seqs, - self.compilation_config.max_capture_size) - self.state_indices_tensor = torch.empty( - (self.decode_cudagraph_max_bs, ), - dtype=torch.int32, - device=device, - ) def build(self, common_prefix_len: int, @@ -187,19 +172,3 @@ class Mamba2AttentionMetadataBuilder( state_indices_tensor=state_indices_tensor, ) return attn_metadata - - def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata): - """ - This method builds the metadata for full cudagraph capture. - Currently, only decode is supported for full cudagraphs with Mamba. - """ - m = common_attn_metadata - - assert m.num_reqs == m.num_actual_tokens, \ - "Mamba only supports decode-only full CUDAGraph capture. " \ - "Make sure all cudagraph capture sizes <= max_num_seq." 
- - m.max_query_len = 1 # decode-only - - return self.build(0, m) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py new file mode 100644 index 0000000000000..07ef7cb69a160 --- /dev/null +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import abc +from typing import ClassVar, TypeVar + +import torch + +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec + +M = TypeVar("M") + + +class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): + reorder_batch_threshold: ClassVar[int] = 1 + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + assert isinstance(kv_cache_spec, MambaSpec) + self.kv_cache_spec = kv_cache_spec + self.device = device + self.vllm_config = vllm_config + self.layer_names = layer_names + + self.compilation_config = vllm_config.compilation_config + self.decode_cudagraph_max_bs = min( + self.vllm_config.scheduler_config.max_num_seqs, + self.compilation_config.max_capture_size) + self.state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.int32, + device=device, + ) + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata) -> M: + """ + This method builds the metadata for full cudagraph capture. + Currently, only decode is supported for full cudagraphs with Mamba. + """ + m = common_attn_metadata + + assert m.num_reqs == m.num_actual_tokens, \ + "Mamba only supports decode-only full CUDAGraph capture. 
" \ + "Make sure all cudagraph capture sizes <= max_num_seq." + + m.max_query_len = 1 # decode-only + + return self.build(0, m) \ No newline at end of file diff --git a/vllm/v1/attention/backends/mamba_selectors.py b/vllm/v1/attention/backends/mamba_selectors.py deleted file mode 100644 index d3a0c63c5e964..0000000000000 --- a/vllm/v1/attention/backends/mamba_selectors.py +++ /dev/null @@ -1,18 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.attention.backends.abstract import AttentionBackend -from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend -from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend -from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend - - -def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]: - if mamba_type == "mamba1": - return Mamba1AttentionBackend - if mamba_type == "mamba2": - return Mamba2AttentionBackend - if mamba_type == "linear_attention": - return LinearAttentionBackend - - raise NotImplementedError(f"Mamba Attention type {mamba_type} is not " - "supported yet.") diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index f2610671f769e..ce45b34f64355 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -24,7 +24,7 @@ Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). Deepseek's MLA attention works the following way: -* Use a single latent vector to represent the per-token entry of the KV cache. +* Use a single latent vector to represent the per-token entry of the KV cache. * For decode (i.e. the memory friendly approach) the attention "simulates" a multi-head attention, while the compute is similar to multi-query attention. 
@@ -82,7 +82,7 @@ spda_o = scaled_dot_product_attention( torch.cat([q_nope, q_pe], dim=-1), torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), v -) +) return spda_o @ W_O NOTE: in the actual code, @@ -120,20 +120,20 @@ return o.view(-1, N * V) @ self.num_heads @ W_O ## Chunked Prefill -For chunked prefill we want to use the compute friendly algorithm. We are -assuming sufficiently large Sq / Skv ratio, in the future may want to switch to +For chunked prefill we want to use the compute friendly algorithm. We are +assuming sufficiently large Sq / Skv ratio, in the future may want to switch to the data-movement friendly approach if the chunk (i.e. `Sq`) is small. However, the compute-friendly approach can potentially run out of memory if Skv is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` -To mitigate this, we chunk the computation of attention with respect to the -current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a +To mitigate this, we chunk the computation of attention with respect to the +current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a fixed workspace size. 
The chunked prefill approach is as follows: -MCC Max chunk of context to process per iter, computed dynamically, +MCC Max chunk of context to process per iter, computed dynamically, used to bound the memory usage q_c = h_t @ W_DQ @@ -155,7 +155,7 @@ curr_o, curr_lse = scaled_dot_product_attention( new_v, casual=True, return_softmax_lse=True -) +) // Compute attention with the already existing context for chunk_idx in range(cdiv(C, MCC)): @@ -416,7 +416,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): self.model_config = vllm_config.model_config cache_config = vllm_config.cache_config parallel_config = vllm_config.parallel_config - self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled self.num_heads = self.model_config.get_num_attention_heads( parallel_config) self.mla_dims = get_mla_dims(self.model_config) @@ -426,30 +425,28 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if self.aot_schedule: self.page_size = self.kv_cache_spec.block_size - if self.chunked_prefill_enabled: - self.chunked_prefill_workspace_size = min( - # Max sure there is enough for 8 full length request or at least - # 4 pages of cache per request - max( - 8 * self.model_config.max_model_len, 4 * - scheduler_config.max_num_seqs * cache_config.block_size), - # For long-context models try not to over-allocate limiting - # kv-cache space, limiting it to 64k tokens, - # which would result in the workspace being: - # 2*(576)*(64*1024) = 144mb - # (assuming 576 MLA head dim, and fp16) - # which would result in up-projected context being - # 2*(192*128)*(64*1024) = 3gb - # (assuming 192 QK head dim, 128 heads, and fp16) - 128 * 1024) - assert self.chunked_prefill_workspace_size >= \ - scheduler_config.max_num_seqs * cache_config.block_size - self.chunked_prefill_workspace = torch.empty( - (self.chunked_prefill_workspace_size, - self.model_config.get_head_size()), - dtype=self.model_config.dtype, - device=device, - ) + self.chunked_prefill_workspace_size 
= min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max(8 * self.model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * cache_config.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * cache_config.block_size + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) self._use_cudnn_prefill = use_cudnn_prefill() self._use_fi_prefill = use_flashinfer_prefill() @@ -620,8 +617,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): reqs_start:] - query_start_loc[reqs_start] chunked_context_metadata = None - if self.chunked_prefill_enabled and num_prefills > 0 \ - and max_context_len_cpu > 0: + if max_context_len_cpu > 0: # NOTE: it is recommend you read the `Chunked Prefill` section # in the comment at the top of the file before trying to # understand the following code @@ -635,8 +631,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if self.aot_schedule: # align max_context_chunk to page_size by rounding down, - # currently the `gather_cache` kernel cannot handle - # `context_chunk_starts` that are not aligned to page_size + # currently the `gather_and_maybe_dequant_cache` kernel + # cannot handle `context_chunk_starts` that are not aligned + # to page_size max_context_chunk = round_down(max_context_chunk, self.page_size) @@ -1009,6 +1006,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): q: torch.Tensor, kv_c_and_k_pe_cache: 
torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ): assert attn_metadata.prefill is not None prefill_metadata = attn_metadata.prefill @@ -1021,12 +1019,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] - ops.gather_cache( + ops.gather_and_maybe_dequant_cache( src_cache=kv_c_and_k_pe_cache, dst=workspace, block_table=prefill_metadata.block_table, cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i], batch_size=attn_metadata.num_prefills, + kv_cache_dtype=self.kv_cache_dtype, + scale=k_scale, seq_starts=prefill_metadata.chunked_context.starts[i], ) @@ -1077,6 +1077,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ) -> torch.Tensor: assert attn_metadata.prefill is not None @@ -1099,7 +1100,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): if has_context: suffix_output, suffix_lse = output context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata) + q, kv_c_and_k_pe_cache, attn_metadata, k_scale) output = torch.empty_like(suffix_output) merge_attn_states( @@ -1123,6 +1124,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, + layer: AttentionLayer, ) -> torch.Tensor: raise NotImplementedError @@ -1136,10 +1138,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): attn_metadata: M, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." 
- if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for MLACommonImpl") @@ -1150,6 +1153,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): # same expert outputs. return output.fill_(0) + fp8_attention = self.kv_cache_dtype.startswith("fp8") + num_actual_toks = attn_metadata.num_actual_tokens # Inputs and outputs may be padded for CUDA graphs @@ -1184,10 +1189,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): scale=layer._k_scale, ) + if fp8_attention: + kv_cache = kv_cache.view(current_platform.fp8_dtype()) + if has_prefill: output[num_decode_tokens:] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata) + attn_metadata, layer._k_scale) if has_decode: assert attn_metadata.decode is not None @@ -1200,7 +1208,21 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): # Convert from (N, B, L) to (B, N, L) decode_ql_nope = decode_ql_nope.transpose(0, 1) + if fp8_attention: + ql_nope_shape = decode_ql_nope.shape + decode_ql_nope, _ = ops.scaled_fp8_quant( + decode_ql_nope.reshape([ + ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2] + ]), layer._q_scale) + decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape) + q_pe_shape = decode_q_pe.shape + decode_q_pe, _ = ops.scaled_fp8_quant( + decode_q_pe.reshape( + [q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), + layer._q_scale) + decode_q_pe = decode_q_pe.reshape(q_pe_shape) + output[:num_decode_tokens] = self._forward_decode( - decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) + decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer) return output_padded diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 6e1e5d6533dab..8a17d3a492783 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -7,7 +7,7 @@ from typing 
import ClassVar, Optional import torch import vllm._custom_ops as ops -from vllm.attention.backends.abstract import (AttentionType, +from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, is_quantized_kv_cache) from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, @@ -21,7 +21,7 @@ logger = init_logger(__name__) class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): # enable full CUDA Graph support for decode-only capture - attn_cudagraph_support: ClassVar[ + cudagraph_support: ClassVar[ AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE @@ -115,7 +115,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): self._use_old_cutlass_mla = False force_old_cutlass = os.environ.get("FORCE_OLD_CUTLASS_MLA", None) if force_old_cutlass: - logger.warning("Forcing old cutlass mla kernel") + logger.warning_once("Forcing old cutlass mla kernel") self._use_old_cutlass_mla = True # TODO: Currently, num_kv_splits is limited to 16 to avoid hanging @@ -123,8 +123,8 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): # FORCE_NUM_KV_SPLITS=1 force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS", None) if force_num_kv_splits: - logger.warning("Forcing num_kv_splits to %d", - int(force_num_kv_splits)) + logger.warning_once("Forcing num_kv_splits to %d", + int(force_num_kv_splits)) self._num_kv_splits = int(force_num_kv_splits) else: self._num_kv_splits = -1 # => Auto-detect @@ -278,6 +278,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + layer: AttentionLayer, ) -> torch.Tensor: if self._use_old_cutlass_mla: # TODO: Remove the old cutlass MLA kernel after more extensive diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 11674423400ce..1c50144d47900 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py 
+++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -6,8 +6,7 @@ from typing import ClassVar, Optional import torch -from vllm.attention.backends.abstract import (AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import AttentionLayer, AttentionType from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, get_mla_metadata, is_flashmla_supported) @@ -166,16 +165,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): "are not implemented for " "FlashMLAImpl") - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "FlashMLA V1 with FP8 KV cache not yet supported") - def _forward_decode( self, q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashMLAMetadata, + layer: AttentionLayer, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None @@ -194,6 +190,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): num_splits=attn_metadata.decode.num_splits, softmax_scale=self.scale, causal=True, + descale_q=layer._q_scale.reshape(1), + descale_k=layer._k_scale.reshape(1), ) return self._v_up_proj(o) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 082c7e6f7c62e..870cc600388e7 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -7,6 +7,7 @@ from typing import ClassVar, Optional import torch import vllm.envs as envs +from vllm.attention.backends.abstract import AttentionLayer from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd from vllm.config import VllmConfig from vllm.utils import cdiv @@ -221,6 +222,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: AiterMLAMetadata, + layer: AttentionLayer, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None diff 
--git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 700fce68953e5..f2974ed668d99 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -6,7 +6,7 @@ from typing import Optional import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionType, +from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, is_quantized_kv_cache) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention @@ -127,6 +127,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + layer: AttentionLayer, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 9b122136afb7f..fd97db0abb84f 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -5,12 +5,6 @@ from dataclasses import dataclass from typing import Optional import torch -import torch_xla.core.xla_builder as xb -import torch_xla.experimental.custom_kernel # noqa: F401 -# Required to register custom ops. 
-from torch.library import impl -from torch_xla._internal.jax_workarounds import requires_jax -from torch_xla.experimental.custom_kernel import XLA_LIB from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) @@ -37,6 +31,57 @@ TPU_STR_DTYPE_TO_TORCH_DTYPE = { "uint8": torch.uint8, } +try: + import tpu_commons # noqa: F401 +except ImportError: + # Lazy import torch_xla + import torch_xla.core.xla_builder as xb + import torch_xla.experimental.custom_kernel # noqa: F401 + from torch.library import impl + from torch_xla._internal.jax_workarounds import requires_jax + from torch_xla.experimental.custom_kernel import XLA_LIB + + @requires_jax + def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, num_slices_per_block: int): + from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update + new_kv_cache = xb.call_jax( + kv_cache_update, + (kv, slot_mapping, kv_cache, num_kv_update_slices), { + "page_size": page_size, + "num_slices_per_block": num_slices_per_block + }) + return new_kv_cache + + + XLA_LIB.define( + "kv_cache_update_op(Tensor kv, Tensor slot_mapping," \ + "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," \ + "int num_slices_per_block)" \ + "-> Tensor", ) + + @impl(XLA_LIB, "kv_cache_update_op", "XLA") + def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, + num_slices_per_block: int) -> torch.Tensor: + new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, + num_kv_update_slices, page_size, + num_slices_per_block) + return new_kv_cache + + @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") + def kv_cache_update_op_non_xla(kv: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, + 
num_slices_per_block: int) -> torch.Tensor: + return kv_cache + class PallasAttentionBackend(AttentionBackend): @@ -182,6 +227,7 @@ class PallasAttentionBackendImpl(AttentionImpl): attn_metadata: PallasMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. @@ -194,7 +240,7 @@ class PallasAttentionBackendImpl(AttentionImpl): Returns: shape = [num_tokens, num_heads * head_size] """ - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for PallasAttentionBackendImpl") @@ -313,46 +359,6 @@ def write_to_kv_cache( kv_cache.copy_(new_kv_cache) -@requires_jax -def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, page_size: int, - num_slices_per_block: int): - from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update - new_kv_cache = xb.call_jax( - kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), { - "page_size": page_size, - "num_slices_per_block": num_slices_per_block - }) - return new_kv_cache - - -XLA_LIB.define( - "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \ - "Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \ - "-> Tensor", ) - - -@impl(XLA_LIB, "kv_cache_update_op", "XLA") -def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, page_size: int, - num_slices_per_block: int) -> torch.Tensor: - new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, - num_kv_update_slices, page_size, - num_slices_per_block) - return new_kv_cache - - -@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") -def kv_cache_update_op_non_xla(kv: torch.Tensor, 
slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, - num_slices_per_block: int) -> torch.Tensor: - return kv_cache - - # We can move this function to a common utils file if it's also useful for other # hardware. def dtype_bits(dtype: torch.dtype): diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 7d09ac0a4a3a1..403ad8e88a958 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with AiterFlashAttention.""" from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import Optional import torch @@ -11,7 +11,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec @@ -231,7 +232,7 @@ class AiterFlashAttentionMetadata: class AiterFlashAttentionMetadataBuilder( AttentionMetadataBuilder[AiterFlashAttentionMetadata]): - full_cudagraph_supported: ClassVar[bool] = True + cudagraph_support = AttentionCGSupport.ALWAYS def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -269,7 +270,7 @@ class AiterFlashAttentionMetadataBuilder( num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc seq_lens = 
common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor @@ -420,6 +421,7 @@ class AiterFlashAttentionImpl(AttentionImpl): attn_metadata: AiterFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with AiterFlashAttention. @@ -437,7 +439,7 @@ class AiterFlashAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py new file mode 100644 index 0000000000000..d80ced8ec876a --- /dev/null +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import ClassVar, Optional + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec + + +class ShortConvAttentionBackend(AttentionBackend): + + @staticmethod + def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]: + return ShortConvAttentionMetadataBuilder + + +@dataclass +class ShortConvAttentionMetadata: + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int + + query_start_loc: torch.Tensor + has_initial_states: torch.Tensor + state_indices_tensor: torch.Tensor # shape: [batch,] + + # For causal_conv1d + nums_dict: Optional[dict] = None + cu_seqlen: 
Optional[int] = None + batch_ptr: Optional[torch.tensor] = None + token_chunk_offset_ptr: Optional[torch.tensor] = None + + +class ShortConvAttentionMetadataBuilder( + AttentionMetadataBuilder[ShortConvAttentionMetadata]): + + reorder_batch_threshold: ClassVar[int] = 1 + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + assert isinstance(kv_cache_spec, MambaSpec) + self.kv_cache_spec = kv_cache_spec + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> ShortConvAttentionMetadata: + num_reqs = common_attn_metadata.num_reqs + query_start_loc = common_attn_metadata.query_start_loc + + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=1)) + has_initial_states = None + if num_prefills > 0: + #[batch,] + has_initial_states_cpu = ( + common_attn_metadata. 
+ num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0) + has_initial_states = has_initial_states_cpu.to( + query_start_loc.device) + + attn_metadata = ShortConvAttentionMetadata( + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + query_start_loc=query_start_loc, + has_initial_states=has_initial_states, + state_indices_tensor=state_indices_tensor, + ) + return attn_metadata \ No newline at end of file diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 5d10e9e26082d..c93223a340839 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -205,7 +205,7 @@ class TreeAttentionMetadataBuilder( q_start_loc = common_attn_metadata.query_start_loc max_query_len = common_attn_metadata.max_query_len kv_seqlens = common_attn_metadata.seq_lens - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping @@ -354,6 +354,7 @@ class TreeAttentionImpl(AttentionImpl): attn_metadata: TreeAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with TreeAttention. @@ -368,7 +369,7 @@ class TreeAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." 
- if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for TreeAttentionImpl") diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 48a9af3decac0..b12036c599799 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -90,7 +90,7 @@ class TritonAttentionMetadataBuilder( num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor @@ -277,6 +277,7 @@ class TritonAttentionImpl(AttentionImpl): attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -291,7 +292,7 @@ class TritonAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." 
- if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for TritonAttentionImpl") diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 1c7d087989649..39bdbe125635b 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -5,8 +5,7 @@ import enum import functools from abc import abstractmethod from dataclasses import dataclass, make_dataclass -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional, - TypeVar) +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar import numpy as np import torch @@ -58,6 +57,8 @@ class CommonAttentionMetadata: """Total number of tokens in batch""" max_query_len: int """Longest query in batch""" + max_seq_len: int + """Longest context length in batch""" block_table_tensor: torch.Tensor slot_mapping: torch.Tensor @@ -107,6 +108,7 @@ def _make_metadata_with_slice( seq_lens = attn_metadata.seq_lens[request_slice] seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice] + max_seq_len = int(seq_lens_cpu.max()) num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[ request_slice] @@ -128,6 +130,7 @@ def _make_metadata_with_slice( num_reqs=num_requests, num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, + max_seq_len=max_seq_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, ) @@ -248,19 +251,23 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): @functools.lru_cache def get_kv_cache_layout(): + # Format specified by the code. global _KV_CACHE_LAYOUT_OVERRIDE - # Override with format specified by the user. + + if _KV_CACHE_LAYOUT_OVERRIDE is not None: + cache_layout = _KV_CACHE_LAYOUT_OVERRIDE + logger.info_once("`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. 
" \ + "Setting KV cache layout to %s.", cache_layout) + return cache_layout + + # Format specified by the user. cache_layout = envs.VLLM_KV_CACHE_LAYOUT + # When neither the user nor the override specified a layout, get default if cache_layout is None: - if envs.VLLM_USE_TRTLLM_ATTENTION: - cache_layout = "HND" - else: - cache_layout = get_kv_connector_cache_layout() + cache_layout = get_kv_connector_cache_layout() else: logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \ "detected. Setting KV cache layout to %s.", cache_layout) - if _KV_CACHE_LAYOUT_OVERRIDE is not None: - cache_layout = _KV_CACHE_LAYOUT_OVERRIDE return cache_layout @@ -460,8 +467,9 @@ def make_local_attention_virtual_batches( attn_chunk_size)[arange > 0] # convert from q_seqlens to cu_seqlens_q - cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\ - .astype(np.int32) + cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32) + np.cumsum(seqlens_q_local, out=cu_seqlens_q_local[1:]) + cu_seqlens_q_local[0] = 0 # compute the seqlens_k_local, # basically a full local attention block for all but the last block in each @@ -504,11 +512,10 @@ def make_local_attention_virtual_batches( # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) # ] - block_indices= np.broadcast_to( - np.arange(pages_per_local_batch, dtype=np.int32), - (virtual_batches, pages_per_local_batch)) \ - + np.expand_dims(block_starts, axis=1) - block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1) + block_indices = (block_starts[:, None] + + np.arange(pages_per_local_batch, dtype=np.int32)) + block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - + 1) batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), local_blocks * pages_per_local_batch) block_table_local = block_table[batch_indices, block_indices]\ @@ -516,6 +523,7 @@ def make_local_attention_virtual_batches( 
query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local) seq_lens_cpu = torch.from_numpy(seqlens_k_local) + max_seq_len = int(seq_lens_cpu.max()) return CommonAttentionMetadata( query_start_loc_cpu=query_start_loc_cpu, @@ -527,41 +535,13 @@ def make_local_attention_virtual_batches( num_reqs=len(seq_lens_cpu), num_actual_tokens=common_attn_metadata.num_actual_tokens, max_query_len=seqlens_q_local.max(), + max_seq_len=max_seq_len, block_table_tensor=block_table_local, slot_mapping=common_attn_metadata.slot_mapping, causal=True, ) -def subclass_attention_metadata_builder( - name_prefix: str, - builder_cls: type[AttentionMetadataBuilder[M]], - build_preprocess_fn: Callable[[CommonAttentionMetadata], - CommonAttentionMetadata], -) -> type[AttentionMetadataBuilder[M]]: - """ - Return a new subclass of `builder_cls` whose .build(...) method - first calls build_preprocess_fn(common_attn_metadata) on the metadata. - """ - name: str = name_prefix + builder_cls.__name__ # type: ignore - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False): - return builder_cls.build(self, common_prefix_len, - build_preprocess_fn(common_attn_metadata), - fast_build) - - Wrapped = type( - name, - (builder_cls, ), # inherit from the original - { - "build": build, - }) - return Wrapped # type: ignore - - def subclass_attention_backend( name_prefix: str, attention_backend_cls: type[AttentionBackend], builder_cls: type[AttentionMetadataBuilder[M]] diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index fe732c6017702..e0eb7d8be9746 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -231,7 +231,7 @@ class XFormersAttentionMetadataBuilder( q_seqlens = torch.diff(q_start_loc) max_query_len = common_attn_metadata.max_query_len kv_seqlens = common_attn_metadata.seq_lens - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = 
common_attn_metadata.max_seq_len block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping @@ -322,6 +322,7 @@ class XFormersAttentionImpl(AttentionImpl): attn_metadata: XFormersAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with XFormers. @@ -336,7 +337,7 @@ class XFormersAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for XFormersAttentionImpl") diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index ad9854dd29c38..fdd96c3e9557d 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -2,15 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import defaultdict from collections.abc import Iterable -from typing import Callable, Optional +from typing import Optional from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved, BlockStored, KVCacheEvent) from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - FreeKVCacheBlockQueue, KVCacheBlock, - generate_block_hash_extra_keys, - hash_block_tokens) + FreeKVCacheBlockQueue, KVCacheBlock) from vllm.v1.request import Request logger = init_logger(__name__) @@ -97,84 +95,39 @@ class BlockPool: self, request: Request, blocks: list[KVCacheBlock], - block_hashes: list[BlockHash], num_cached_blocks: int, num_full_blocks: int, block_size: int, kv_cache_group_id: int, - hash_fn: Callable, ) -> None: """Cache a list of full blocks for prefix caching. This function takes a list of blocks that will have their block hash - metadata to be updated and cached. 
Given a request, it computes the - block hashes for the blocks starting from `num_cached_blocks` to - `num_full_blocks`, updating the metadata for each block - and caching them in the `cached_block_hash_to_block`. + metadata to be updated and cached. Given a request, it updates the + metadata for each block and caches it in the + `cached_block_hash_to_block`. + The block hash values are computed by the Request object immediately + when it is created and when new tokens are appended. Args: request: The request to cache the blocks. blocks: All blocks in the request. - block_hashes: Block hashes of the blocks in the request. Note that - this list may be shorter than the blocks list. In this case the - missed block hash will be computed in this function. num_cached_blocks: The number of blocks that are already cached. num_full_blocks: The number of blocks that are full and should be cached after this function. block_size: Number of tokens in each block. kv_cache_group_id: The id of the KV cache group. - hash_fn: The hash function to use for block hashes. """ if num_cached_blocks == num_full_blocks: return new_full_blocks = blocks[num_cached_blocks:num_full_blocks] - assert len(block_hashes) >= num_cached_blocks - new_block_hashes = block_hashes[num_cached_blocks:] + assert len(request.block_hashes) >= num_full_blocks + new_block_hashes = request.block_hashes[num_cached_blocks:] - # Update the new blocks with the block hashes through the chain. 
- if num_cached_blocks == 0: - prev_block_hash_value = None - else: - prev_block = blocks[num_cached_blocks - 1] - assert prev_block.block_hash is not None - prev_block_hash_value = prev_block.block_hash.get_hash_value() - - parent_block_hash = prev_block_hash_value new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events else None) for i, blk in enumerate(new_full_blocks): assert blk.block_hash is None - - if i < len(new_block_hashes): - # The block hash may already be computed in - # "get_computed_blocks" if the tokens are not generated by - # this request (either the prompt tokens or the previously - # generated tokens with preemption), or by other - # single_type_managers with the same block_size. - # In this case we simply reuse the block hash. - block_hash = new_block_hashes[i] - else: - # Otherwise compute the block hash and cache it in the request - # in case it will be preempted in the future. - blk_idx = num_cached_blocks + i - start_token_idx = blk_idx * block_size - end_token_idx = (blk_idx + 1) * block_size - block_tokens = request.all_token_ids[ - start_token_idx:end_token_idx] - assert len(block_tokens) == block_size, ( - f"Expected {block_size} tokens, got " - f"{len(block_tokens)} at {blk_idx}th block for request " - f"{request.request_id}({request})") - - # Generate extra keys for multi-modal inputs. Note that since - # we reach to this branch only when the block is completed with - # generated tokens, we only need to consider the last mm input. - extra_keys, _ = generate_block_hash_extra_keys( - request, start_token_idx, end_token_idx, -1) - - # Compute the hash of the current block. - block_hash = hash_block_tokens(hash_fn, prev_block_hash_value, - block_tokens, extra_keys) - block_hashes.append(block_hash) + block_hash = new_block_hashes[i] # Update and added the full block to the cache. 
block_hash_with_group_id = BlockHashWithGroupId( @@ -184,9 +137,15 @@ class BlockPool: blk.block_id] = blk if new_hashes is not None: new_hashes.append(block_hash.hash_value) - prev_block_hash_value = block_hash.hash_value if self.enable_kv_cache_events: + if num_cached_blocks == 0: + parent_block_hash = None + else: + parent_block = blocks[num_cached_blocks - 1] + assert parent_block.block_hash is not None + parent_block_hash = parent_block.block_hash.get_hash_value() + self.kv_event_queue.append( BlockStored( block_hashes=new_hashes, @@ -339,7 +298,12 @@ class BlockPool: Returns: The KV cache usage (between 0.0 and 1.0). """ - return 1.0 - (self.get_num_free_blocks() / self.num_gpu_blocks) + + # Subtract 1 to account for null block. + total_gpu_blocks = self.num_gpu_blocks - 1 + if not total_gpu_blocks: + return 0 + return 1.0 - (self.get_num_free_blocks() / total_gpu_blocks) def take_events(self) -> list[KVCacheEvent]: """Atomically takes all events and clears the queue. diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index faf5c132f8640..c9d18033a1988 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import OrderedDict +from collections.abc import Mapping from typing import TYPE_CHECKING from vllm.logger import init_logger @@ -31,34 +33,52 @@ class EncoderCacheManager: within requests, allowing for fine-grained memory management and enabling chunked processing of multimodal inputs. - Note that no caching is shared between requests at this time. If the same - input is used across multiple requests, it will be reprocessed for each - request. 
+ Cache is enabled to share embeddings of the same multimodal data + item (identified by its hash value) between different requests, + and eviction takes place at allocation time when there's no free + space for new embeddings. + Oldest cached embeddings not referenced by any request are evicted first. Args: cache_size: Limit the size of the cache, measured by the number of tokens from the input sequence. Attributes: - cache_size: Total cache capacity in encoder tokens - num_free_slots: Current available cache capacity in encoder tokens - cached: Mapping from request_id to set of cached input_ids for that - request - freed: List of (request_id, input_id) pairs that were recently freed. - This is cleared after every call to get_freed_ids(). + cache_size: Total cache capacity in encoder tokens. + num_free_slots: Current available cache capacity in encoder tokens. + num_freeable_slots: Capacity that can be immediately reclaimed by + evicting entries with zero references (in encoder tokens). + cached: Mapping from mm_hash to a set of request IDs that currently + reference the cached entry. If the set is empty, the entry exists + but is not referenced by any request and is eligible for + reclamation. + freeable: Ordered mapping from mm_hash to num_tokens for entries + that no running request currently references and that can be freed + to make space when needed. + freed: List of mm_hash strings that were actually evicted since the + last call to get_freed_mm_hashes(). This list is cleared on return. 
""" def __init__(self, cache_size: int): self.cache_size = cache_size self.num_free_slots = cache_size - # req_id -> cached input ids - self.cached: dict[str, set[int]] = {} - # list of [req_id, input_id] - self.freed: list[tuple[str, int]] = [] + self.num_freeable_slots = cache_size - def has_cache(self, request: Request, input_id: int) -> bool: + # mm_hash of mm_data => ids of requests that reference the mm_data + self.cached: dict[str, set[str]] = {} + + # mm_hash of mm_data => num_encoder_tokens of the mm_data + self.freeable: OrderedDict[str, int] = OrderedDict() + self.freed: list[str] = [] + + def check_and_update_cache(self, request: Request, input_id: int) -> bool: """Check if encoder output for a specific multimodal input is cached. + If the encoder output is cached, update `cached` to add the request id + to the set of request ids that reference the cached encoder output. + If the encoder output was previously not referenced by any request, + update `freeable` and `num_freeable_slots` accordingly. + Args: request: The request containing the multimodal input input_id: Index of the multimodal input within the request @@ -66,103 +86,159 @@ class EncoderCacheManager: Returns: True if the encoder output for this input is already cached """ - req_id = request.request_id - return req_id in self.cached and input_id in self.cached[req_id] + mm_hash = request.mm_hashes[input_id] + # Not cached at all + if mm_hash not in self.cached: + return False - def can_allocate(self, request: Request, input_id: int) -> bool: - """Check if there's sufficient cache space for a multimodal input. 
+ # Cached but currently not referenced by any request + if not self.cached[mm_hash]: + num_tokens = self.freeable.pop(mm_hash) + self.num_freeable_slots -= num_tokens + + self.cached[mm_hash].add(request.request_id) + return True + + def can_allocate(self, request: Request, input_id: int, + encoder_compute_budget: int, + num_tokens_to_schedule: int) -> bool: + """Check if there's sufficient cache space for a multimodal input. + If there is, return True and update EncoderCacheManager state. + + If there is not enough free space in `num_free_slots` but there is + enough reclaimable space in `num_freeable_slots`, entries will be + evicted from `freeable` (their mm_hash appended to `freed`) until + enough space is available, and then this method returns True. + Older entries are evicted first. + + Returns False only if the requested number of tokens exceeds both + the free and reclaimable capacities combined. Args: - request: The request containing the multimodal input - input_id: Index of the multimodal input within the request + request: The request containing the multimodal input. + input_id: Index of the multimodal input within the request. + encoder_compute_budget: Number of encoder tokens allowed to be + computed when this method is invoked. + num_tokens_to_schedule: Number of tokens already scheduled to be + allocated with cache space when this method is invoked. Returns: - True if there's enough free cache space to store the encoder output - for this multimodal input + True if there's enough capacity to hold the encoder output for this + input (possibly after reclaiming `freeable` entries); otherwise + False. + + Note: This method does not allocate physical memory for the encoder + output but only the state of EncoderCacheManager. 
""" num_tokens = request.get_num_encoder_tokens(input_id) - return num_tokens <= self.num_free_slots + + # Not enough compute budget + if num_tokens > encoder_compute_budget: + return False + + num_tokens += num_tokens_to_schedule + + # Enough free slots + if num_tokens <= self.num_free_slots: + return True + + # Not enough reclaimable slots + if num_tokens > self.num_freeable_slots: + return False + + # Not enough free slots but enough reclaimable slots + # NOTE: Eviction takes place here, but physical memory is not freed + # until model runner is notified by the scheduler output. + while num_tokens > self.num_free_slots: + mm_hash, num_free_token = self.freeable.popitem(last=False) + del self.cached[mm_hash] + self.freed.append(mm_hash) + self.num_free_slots += num_free_token + return True def allocate(self, request: Request, input_id: int) -> None: """Allocate cache space for a multimodal input's encoder output. - This method reserves cache space for storing the encoder output of - the specified multimodal input. The actual encoder output storage - happens in the model runner, but this method ensures the cache - manager tracks the allocation. - - Args: - request: The request containing the multimodal input - input_id: Index of the multimodal input within the request + This reserves cache space for storing the encoder output of the + specified multimodal input. The actual encoder output storage happens in + the model runner; this method updates the manager's bookkeeping. Note: - This method assumes can_allocate() returned True for the same - request and input_id. It will reduce available cache space. + This method assumes can_allocate() returned True for the same input. 
""" - req_id = request.request_id - if req_id not in self.cached: - self.cached[req_id] = set() - self.cached[req_id].add(input_id) - self.num_free_slots -= request.get_num_encoder_tokens(input_id) + + mm_hash = request.mm_hashes[input_id] + request_id = request.request_id + if mm_hash not in self.cached: + self.cached[mm_hash] = set() + + num_encoder_tokens = request.get_num_encoder_tokens(input_id) + + # NOTE: Encoder cache should always have enough space for encoder inputs + # that are scheduled since eviction takes place at can_allocate(). + assert self.num_free_slots >= num_encoder_tokens + assert self.num_freeable_slots >= num_encoder_tokens + + self.cached[mm_hash].add(request_id) + self.num_free_slots -= num_encoder_tokens + self.num_freeable_slots -= num_encoder_tokens def get_cached_input_ids(self, request: Request) -> set[int]: """Get all cached multimodal input IDs for a request. - Args: - request: The request to query - - Returns: - Set of input_ids that have cached encoder outputs for this request. - Returns empty set if no inputs are cached for this request. + Returns the set of input IDs whose `mm_hash` exists in the cache map. + This includes entries that are currently unreferenced (and thus present + in `freeable`); for such entries, freeing for this request will be a + no-op. """ - return self.cached.get(request.request_id, set()) + return { + input_id + for input_id in range(len(request.mm_hashes)) + if request.mm_hashes[input_id] in self.cached + } def free_encoder_input(self, request: Request, input_id: int) -> None: - """Free cache space for a single multimodal input's encoder output. 
+ """Free the request's reference to the encoder input (`mm_data`) - This method is called when: - - The encoder output has been fully consumed by the decoder and is - no longer needed (e.g., in vision-language models after image - tokens are processed) - - A request is being cancelled or aborted + When the reference set for the corresponding `mm_hash` becomes empty, + the entry is appended to `freeable` and `num_freeable_slots` is + increased by the number of encoder tokens for that input. - Args: - request: The request containing the multimodal input - input_id: Index of the multimodal input to free from cache + The entry is NOT physically freed until capacity is needed (e.g., by + `can_allocate`). """ req_id = request.request_id - if req_id not in self.cached: + mm_hash = request.mm_hashes[input_id] + # The mm_hash not in cache or the req_id set is empty + if not self.cached.get(mm_hash, None): return - - self.cached[req_id].discard(input_id) - if len(self.cached[req_id]) == 0: - del self.cached[req_id] - self.num_free_slots += request.get_num_encoder_tokens(input_id) - self.freed.append((req_id, input_id)) + self.cached[mm_hash].discard(req_id) + if not self.cached[mm_hash]: + num_tokens = request.get_num_encoder_tokens(input_id) + self.freeable[mm_hash] = num_tokens + self.num_freeable_slots += num_tokens def free(self, request: Request) -> None: - """Free all cached encoder outputs for a request. + """Free all encoder input cache reference held by *request*. - This method is typically called when a request is finished, cancelled, - or aborted, and all its encoder outputs should be freed from cache. + For each cached input ID, `free_encoder_input` is invoked. + The data stays in memory until eviction is triggered by a future + attempt allocation called by 'can_allocate'. - Args: - request: The request whose encoder outputs should be freed + Typically called when a request is finished, cancelled, or aborted. 
""" input_ids = self.get_cached_input_ids(request).copy() for input_id in input_ids: self.free_encoder_input(request, input_id) - def get_freed_ids(self) -> list[tuple[str, int]]: + def get_freed_mm_hashes(self) -> list[str]: """Get and clear the list of recently freed encoder cache entries. - This method returns all encoder cache entries that were freed since - the last call to this method. It's used by the scheduler to notify - workers about which encoder outputs can be removed from their caches. - Returns: - List of (request_id, input_id) tuples that were freed since the - last call. The internal freed list is cleared after this call. + List of mm_hash strings that were actually evicted since the last + call to be used by the scheduler to notify workers about which + encoder outputs can be removed from their caches. The internal + list is cleared after this call. """ freed = self.freed self.freed = [] @@ -177,10 +253,31 @@ def compute_encoder_budget( """Compute the encoder cache budget based on the model and scheduler configurations. + Returns: + - Compute budget for encoder execution, measured in number of tokens + from the input sequence. + - Space budget for encoder cache size, measured in number of tokens + from the input sequence. + """ + if mm_registry.supports_multimodal_inputs(model_config): + max_tokens_by_modality = mm_registry \ + .get_max_tokens_per_item_by_nonzero_modality(model_config) + + return compute_mm_encoder_budget( + scheduler_config, + max_tokens_by_modality, + ) + + return compute_text_encoder_budget(scheduler_config) + + +def compute_text_encoder_budget( + scheduler_config: "SchedulerConfig") -> tuple[int, int]: + """Compute the encoder cache budget based on the model and scheduler + configurations for a text-only model. + Args: - model_config: Model configuration. scheduler_config: Scheduler configuration. - mm_registry: Provides information about the token cost. 
Returns: - Compute budget for encoder execution, in unit of number of tokens @@ -188,55 +285,37 @@ def compute_encoder_budget( - Space budget for encoder cache size, in unit of number of tokens in the input sequence. """ - - if not mm_registry.supports_multimodal_inputs(model_config): - return 0, 0 - - # TODO: handle encoder-decoder models once we support them. - ( - encoder_compute_budget, - encoder_cache_size, - ) = _compute_encoder_budget_multimodal( - model_config, - scheduler_config, - mm_registry, - ) - - return encoder_compute_budget, encoder_cache_size + # Currently text-only encoder-decoder models are not supported + return 0, 0 -def _compute_encoder_budget_multimodal( - model_config: "ModelConfig", +def compute_mm_encoder_budget( scheduler_config: "SchedulerConfig", - mm_registry: MultiModalRegistry, + max_tokens_by_modality: Mapping[str, int], ) -> tuple[int, int]: """Compute the encoder cache budget based on the model and scheduler configurations for a multimodal model. Args: - model_config: Model configuration. scheduler_config: Scheduler configuration. - mm_registry: Provides information about the token cost. + max_tokens_by_modality: The maximum number of tokens for each + non-text modality. Returns: - - Compute budget for encoder execution, in unit of number of tokens - in the input sequence. - - Space budget for encoder cache size, in unit of number of tokens - in the input sequence. + - Compute budget for encoder execution, measured in number of tokens + from the input sequence. + - Space budget for encoder cache size, measured in number of tokens + from the input sequence. """ - max_tokens_by_modality_dict = mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(model_config) - - if not max_tokens_by_modality_dict: + if not max_tokens_by_modality: logger.warning( "All non-text modalities supported by the model have been " "explicitly disabled via limit_mm_per_prompt. 
Encoder cache will " "not be initialized.") return 0, 0 - _, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(), - key=lambda item: item[1]) + max_tokens_per_mm_item = max(max_tokens_by_modality.values()) if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item > scheduler_config.max_num_batched_tokens): diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index f3a16d64e19fd..f082ad00f2e35 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Callable, Optional +from typing import Optional from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.core.single_type_kv_cache_manager import ( - FullAttentionManager, get_manager_for_kv_cache_spec) + CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.request import Request @@ -23,7 +23,6 @@ class KVCacheCoordinator(ABC): max_model_len: int, use_eagle: bool, enable_caching: bool, - caching_hash_fn: Callable, enable_kv_cache_events: bool, ): self.kv_cache_config = kv_cache_config @@ -40,13 +39,13 @@ class KVCacheCoordinator(ABC): kv_cache_spec=kv_cache_group.kv_cache_spec, block_pool=self.block_pool, kv_cache_group_id=i, - caching_hash_fn=caching_hash_fn, ) for i, kv_cache_group in enumerate( self.kv_cache_config.kv_cache_groups)) - def get_num_blocks_to_allocate( - self, request_id: str, num_tokens: int, - new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int: + def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int, + new_computed_blocks: tuple[ + list[KVCacheBlock], ...], + num_encoder_tokens: int) -> int: """ Get the number of 
blocks needed to be allocated for the request. @@ -56,14 +55,22 @@ class KVCacheCoordinator(ABC): tokens that are already allocated). new_computed_blocks: The new computed blocks just hitting the prefix caching. + num_encoder_tokens: The number of encoder tokens for allocating + blocks for cross-attention. Returns: The number of blocks. """ num_blocks_to_allocate = 0 for i, manager in enumerate(self.single_type_managers): - num_blocks_to_allocate += manager.get_num_blocks_to_allocate( - request_id, num_tokens, new_computed_blocks[i]) + if isinstance(manager, CrossAttentionManager): + # For cross-attention, we issue a single static allocation + # of blocks based on the number of encoder input tokens. + num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + request_id, num_encoder_tokens, []) + else: + num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + request_id, num_tokens, new_computed_blocks[i]) return num_blocks_to_allocate def save_new_computed_blocks( @@ -81,8 +88,11 @@ class KVCacheCoordinator(ABC): manager.save_new_computed_blocks(request_id, new_computed_blocks[i]) - def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> tuple[list[KVCacheBlock], ...]: + def allocate_new_blocks( + self, + request_id: str, + num_tokens: int, + num_encoder_tokens: int = 0) -> tuple[list[KVCacheBlock], ...]: """ Allocate new blocks for the request to give it at least `num_tokens` token slots. @@ -91,27 +101,29 @@ class KVCacheCoordinator(ABC): request_id: The request ID. num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). + num_encoder_tokens: The number of encoder tokens for allocating + blocks for cross-attention. Returns: The new allocated blocks. 
""" return tuple( - manager.allocate_new_blocks(request_id, num_tokens) + manager.allocate_new_blocks( + request_id, num_encoder_tokens if isinstance( + manager, CrossAttentionManager) else num_tokens) for manager in self.single_type_managers) - def cache_blocks(self, request: Request, block_hashes: list[BlockHash], - num_computed_tokens: int) -> None: + def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: """ Cache the blocks for the request. Args: request: The request. - block_hashes: The block hashes of the request. num_tokens: The total number of tokens that need to be cached (including tokens that are already cached). """ for manager in self.single_type_managers: - manager.cache_blocks(request, block_hashes, num_computed_tokens) + manager.cache_blocks(request, num_computed_tokens) def free(self, request_id: str) -> None: """ @@ -184,10 +196,9 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): """ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, caching_hash_fn: Callable, - enable_kv_cache_events: bool): + use_eagle: bool, enable_kv_cache_events: bool): super().__init__(kv_cache_config, max_model_len, use_eagle, False, - caching_hash_fn, enable_kv_cache_events) + enable_kv_cache_events) self.num_single_type_manager = len(self.single_type_managers) def get_num_common_prefix_blocks(self, request_id: str, @@ -213,10 +224,9 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, - caching_hash_fn: Callable, enable_kv_cache_events: bool): + enable_kv_cache_events: bool): super().__init__(kv_cache_config, max_model_len, use_eagle, - enable_caching, caching_hash_fn, - enable_kv_cache_events) + enable_caching, enable_kv_cache_events) self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[ 0].kv_cache_spec self.block_size = self.kv_cache_spec.block_size @@ -250,10 +260,9 @@ class 
HybridKVCacheCoordinator(KVCacheCoordinator): def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, - caching_hash_fn: Callable, enable_kv_cache_events: bool): + enable_kv_cache_events: bool): super().__init__(kv_cache_config, max_model_len, use_eagle, - enable_caching, caching_hash_fn, - enable_kv_cache_events) + enable_caching, enable_kv_cache_events) self.verify_and_split_kv_cache_groups() def verify_and_split_kv_cache_groups(self) -> None: @@ -386,17 +395,15 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): def get_kv_cache_coordinator( kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, - enable_caching: bool, caching_hash_fn: Callable, + enable_caching: bool, enable_kv_cache_events: bool) -> KVCacheCoordinator: if not enable_caching: return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len, - use_eagle, caching_hash_fn, + use_eagle, enable_kv_cache_events) if len(kv_cache_config.kv_cache_groups) == 1: return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, enable_caching, - caching_hash_fn, enable_kv_cache_events) return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, - enable_caching, caching_hash_fn, - enable_kv_cache_events) + enable_caching, enable_kv_cache_events) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index ce333dbe61a19..b427a9c497fef 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,16 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import defaultdict from dataclasses import dataclass -from typing import Optional +from typing import Literal, Optional, overload from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger -from vllm.utils import sha256, sha256_cbor_64bit from vllm.v1.core.kv_cache_coordinator import 
get_kv_cache_coordinator -from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, - hash_request_tokens, init_none_hash) +from vllm.v1.core.kv_cache_utils import KVCacheBlock from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus @@ -40,7 +37,24 @@ class KVCacheBlocks: tuple(blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks))) - def get_block_ids(self) -> tuple[list[int], ...]: + @overload + def get_block_ids( + self, + allow_none: Literal[False] = False, + ) -> tuple[list[int], ...]: + ... + + @overload + def get_block_ids( + self, + allow_none: Literal[True] = True, + ) -> Optional[tuple[list[int], ...]]: + ... + + def get_block_ids( + self, + allow_none: bool = False, + ): """ Converts the KVCacheBlocks instance to block_ids. @@ -49,6 +63,8 @@ class KVCacheBlocks: * the outer tuple corresponds to KV cache groups * each inner list contains the block_ids of the blocks in that group """ + if allow_none and all(len(group) == 0 for group in self.blocks): + return None return tuple([blk.block_id for blk in group] for group in self.blocks) def get_unhashed_block_ids(self) -> list[int]: @@ -71,23 +87,13 @@ class KVCacheManager: kv_cache_config: KVCacheConfig, max_model_len: int, enable_caching: bool = True, - caching_hash_algo: str = "builtin", use_eagle: bool = False, log_stats: bool = False, enable_kv_cache_events: bool = False, ) -> None: self.max_model_len = max_model_len - if len(kv_cache_config.kv_cache_groups) == 0: - # Attention free models don't have kv cache, - # thus don't need prefix caching. 
- enable_caching = False self.enable_caching = enable_caching - - self.caching_hash_fn = ( - sha256_cbor_64bit if caching_hash_algo == "sha256_cbor_64bit" else - sha256 if caching_hash_algo == "sha256" else hash) - init_none_hash(self.caching_hash_fn) self.use_eagle = use_eagle self.log_stats = log_stats # FIXME: make prefix cache stats conditional on log_stats @@ -107,19 +113,12 @@ class KVCacheManager: max_model_len=self.max_model_len, use_eagle=self.use_eagle, enable_caching=self.enable_caching, - caching_hash_fn=self.caching_hash_fn, enable_kv_cache_events=enable_kv_cache_events, ) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool self.kv_cache_config = kv_cache_config - # Mapping from request ID to kv block hashes. - # This is to avoid recomputing the block hashes for each call of - # `get_computed_blocks` or `allocate_slots`. - self.req_to_block_hashes: defaultdict[ - str, list[BlockHash]] = defaultdict(list) - @property def usage(self) -> float: """Get the KV cache usage. @@ -161,15 +160,6 @@ class KVCacheManager: and request.sampling_params.prompt_logprobs is not None)): return self.create_empty_block_list(), 0 - # The block hashes for the request may already be computed - # if the scheduler has tried to schedule the request before. - block_hashes = self.req_to_block_hashes[request.request_id] - if not block_hashes: - assert self.block_size is not None - block_hashes = hash_request_tokens(self.caching_hash_fn, - self.block_size, request) - self.req_to_block_hashes[request.request_id] = block_hashes - # NOTE: When all tokens hit the cache, we must recompute the last token # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1. # This can trigger recomputation of an entire block, rather than just @@ -178,7 +168,7 @@ class KVCacheManager: # could slightly improve performance in the future. 
max_cache_hit_length = request.num_tokens - 1 computed_blocks, num_new_computed_tokens = ( - self.coordinator.find_longest_cache_hit(block_hashes, + self.coordinator.find_longest_cache_hit(request.block_hashes, max_cache_hit_length)) if self.log_stats: @@ -197,6 +187,7 @@ class KVCacheManager: new_computed_blocks: Optional[KVCacheBlocks] = None, num_lookahead_tokens: int = 0, delay_cache_blocks: bool = False, + num_encoder_tokens: int = 0, ) -> Optional[KVCacheBlocks]: """Add slots for a request with new tokens to append. @@ -263,6 +254,7 @@ class KVCacheManager: request_id=request.request_id, num_tokens=num_tokens_need_slot, new_computed_blocks=new_computed_block_list, + num_encoder_tokens=num_encoder_tokens, ) if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): @@ -283,7 +275,7 @@ class KVCacheManager: new_computed_block_list) new_blocks = self.coordinator.allocate_new_blocks( - request.request_id, num_tokens_need_slot) + request.request_id, num_tokens_need_slot, num_encoder_tokens) # P/D: delay caching blocks if we have to recv from # remote. Update state for locally cached blocks. @@ -296,17 +288,13 @@ class KVCacheManager: # at `request.num_tokens`, ensuring only "finalized" tokens are cached. num_tokens_to_cache = min(num_computed_tokens + num_new_tokens, request.num_tokens) - self.coordinator.cache_blocks( - request, - self.req_to_block_hashes[request.request_id], - num_tokens_to_cache, - ) + self.coordinator.cache_blocks(request, num_tokens_to_cache) return KVCacheBlocks(new_blocks) def free(self, request: Request) -> None: """Free the blocks allocated for the request. - We free the blocks in reverse order so that he tail blocks are evicted + We free the blocks in reverse order so that the tail blocks are evicted first when caching is enabled. 
Args: @@ -373,14 +361,6 @@ class KVCacheManager: return self.coordinator.get_num_common_prefix_blocks( request.request_id, num_running_requests) - def free_block_hashes(self, request: Request) -> None: - """Discard the block hashes for the request. - - NOTE: Unlike `free`, this method should be called only when the request - is finished, not when it is preempted. - """ - self.req_to_block_hashes.pop(request.request_id, None) - def take_events(self) -> list[KVCacheEvent]: """Take the KV cache events from the block pool. @@ -389,17 +369,18 @@ class KVCacheManager: """ return self.block_pool.take_events() + def get_blocks(self, request_id: str) -> KVCacheBlocks: + """Get the blocks of a request.""" + return KVCacheBlocks(self.coordinator.get_blocks(request_id)) + def get_block_ids(self, request_id: str) -> tuple[list[int], ...]: """Get the block ids of a request.""" - return KVCacheBlocks( - self.coordinator.get_blocks(request_id)).get_block_ids() + return self.get_blocks(request_id).get_block_ids() def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: """Cache the blocks for the request, if enabled.""" if self.enable_caching: - block_hashes = self.req_to_block_hashes[request.request_id] - self.coordinator.cache_blocks(request, block_hashes, - num_computed_tokens) + self.coordinator.cache_blocks(request, num_computed_tokens) def create_empty_block_list(self) -> KVCacheBlocks: """Creates a new KVCacheBlocks instance with no blocks.""" diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 626aa35a770c9..6a62c55fb2d5f 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -547,41 +547,61 @@ def hash_block_tokens( curr_block_token_ids_tuple, extra_keys) -def hash_request_tokens(hash_function: Any, block_size: int, - request: Request) -> list[BlockHash]: - """Computes hash values of a chain of blocks given a sequence of - token IDs. The hash value is used for prefix caching. 
- - Args: - block_size: The size of each block. - request: The request object. - - Returns: - The list of computed hash values. +def get_request_block_hasher( + block_size: int, + caching_hash_fn: Callable[[Any], + int]) -> Callable[[Request], list[BlockHash]]: """ - token_ids = request.all_token_ids + Returns a function which computes the list of un-computed block hashes + of a request. - req_need_extra_keys = need_extra_keys(request) - req_extra_keys = None - curr_mm_idx = 0 + Each request holds a list of its block hashes (request.block_hashes). + When a request is created, it calls the below function to compute + the hashes of all full blocks of the request's initial tokens. + The hashes are then stored in request.block_hashes. + Later, whenever new tokens are appended to the request, it calls + the below function again to compute any new full blocks of tokens. + The returned new hashes are appended to request.block_hashes. + """ - ret = [] - parent_block_hash_value = None - # Only full blocks will be hashed - for start in range(0, len(token_ids) - block_size + 1, block_size): - end = start + block_size - block_token_ids = token_ids[start:end] + def request_block_hasher(request: Request) -> list[BlockHash]: + start_token_idx = len(request.block_hashes) * block_size + num_tokens = request.num_tokens + + curr_mm_idx = 0 + if start_token_idx > 0: + # Set curr_mm_idx = -1 to indicate the last mm input. + # Note that since we reach to this branch only when the block is + # completed with generated tokens, we only need to consider the + # last mm input. + curr_mm_idx = -1 + + prev_block_hash_value = request.block_hashes[-1].hash_value \ + if request.block_hashes else None + new_block_hashes: list[BlockHash] = [] + while True: + end_token_idx = start_token_idx + block_size + if end_token_idx > num_tokens: + # We only hash full blocks + break - if req_need_extra_keys: # MM and LoRA requests need extra keys for block-hash computation. 
- req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys( - request, start, end, curr_mm_idx) + extra_keys, curr_mm_idx = generate_block_hash_extra_keys( + request, start_token_idx, end_token_idx, curr_mm_idx) - block_hash = hash_block_tokens(hash_function, parent_block_hash_value, - block_token_ids, req_extra_keys) - ret.append(block_hash) - parent_block_hash_value = block_hash.hash_value - return ret + # Compute the hash of the current block + block_tokens = request.all_token_ids[start_token_idx:end_token_idx] + block_hash = hash_block_tokens(caching_hash_fn, + prev_block_hash_value, block_tokens, + extra_keys) + + new_block_hashes.append(block_hash) + start_token_idx += block_size + prev_block_hash_value = block_hash.hash_value + + return new_block_hashes + + return request_block_hasher def max_memory_usage_bytes(vllm_config: VllmConfig, diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index dd5052a3480b7..5b1de3a66ceb4 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import EngineCoreOutputs from vllm.v1.metrics.stats import SchedulerStats - from vllm.v1.outputs import ModelRunnerOutput + from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -61,6 +61,14 @@ class SchedulerInterface(ABC): """ raise NotImplementedError + @abstractmethod + def update_draft_token_ids( + self, + draft_token_ids: "DraftTokenIds", + ) -> None: + """Update the draft token ids for the scheduled requests.""" + raise NotImplementedError + @abstractmethod def add_request(self, request: "Request") -> None: """Add a new request to the scheduler's internal queue. 
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index fac07f97195bd..b5cd6c5c8af51 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -91,7 +91,7 @@ class CachedRequestData: # NOTE(woosuk): new_token_ids is only used for pipeline parallelism. # When PP is not used, new_token_ids will be empty. new_token_ids: list[list[int]] - new_block_ids: list[tuple[list[int], ...]] + new_block_ids: list[Optional[tuple[list[int], ...]]] num_computed_tokens: list[int] @property @@ -143,9 +143,9 @@ class SchedulerOutput: # steps. This is used to notify the workers about the finished requests # so that they can free the cached states for those requests. finished_req_ids: set[str] - # list of (req_id, encoder_input_index) tuples. - # Used to free the encoder cache. - free_encoder_input_ids: list[tuple[str, int]] + # list of mm_hash strings associated with the encoder outputs to be + # freed from the encoder cache. + free_encoder_mm_hashes: list[str] # Dict of request ids to their index within the batch # for filling the next token bitmask diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index dcb9f4dd36f52..14a914d8f2f0b 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -19,18 +19,18 @@ from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) from vllm.v1.core.sched.request_queue import (SchedulingPolicy, create_request_queue) -from vllm.v1.core.sched.utils import check_stop +from vllm.v1.core.sched.utils import check_stop, remove_all from 
vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats -from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.structured_output import StructuredOutputManager @@ -58,6 +58,7 @@ class Scheduler(SchedulerInterface): self.parallel_config = vllm_config.parallel_config self.log_stats = log_stats self.structured_output_manager = structured_output_manager + self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder # include_finished_set controls whether a separate set of finished # request ids should be included in the EngineCoreOutputs returned @@ -83,6 +84,9 @@ class Scheduler(SchedulerInterface): assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "Multiple KV cache groups are not currently supported " "with KV connectors") + assert not self.is_encoder_decoder, ( + "Encoder-decoder models are not currently supported " + "with KV connectors") self.connector = KVConnectorFactory.create_connector( config=self.vllm_config, role=KVConnectorRole.SCHEDULER) @@ -141,7 +145,6 @@ class Scheduler(SchedulerInterface): cache_size=encoder_cache_size) speculative_config = vllm_config.speculative_config - self.use_eagle = False self.num_spec_tokens = self.num_lookahead_tokens = 0 if speculative_config: @@ -155,7 +158,6 @@ class Scheduler(SchedulerInterface): kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, enable_caching=self.cache_config.enable_prefix_caching, - caching_hash_algo=self.cache_config.prefix_caching_hash_algo, use_eagle=self.use_eagle, log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, @@ -179,20 +181,12 @@ class Scheduler(SchedulerInterface): scheduled_running_reqs: 
list[Request] = [] preempted_reqs: list[Request] = [] - # NOTE: structured_output_request_ids maps - # a request's (request that uses structured output) - # request_id to the running request index. - # This will helps us determine to slice the grammar bitmask - # and only applies valid mask for requests that - # uses structured decoding. - structured_output_request_ids: dict[str, int] = {} - - req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {} + req_to_new_blocks: dict[str, KVCacheBlocks] = {} num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. scheduled_encoder_inputs: dict[str, list[int]] = {} - encoder_budget = self.max_num_encoder_input_tokens + encoder_compute_budget = self.max_num_encoder_input_tokens # Spec decode-related. scheduled_spec_decode_tokens: dict[str, list[int]] = {} @@ -221,12 +215,13 @@ class Scheduler(SchedulerInterface): # Schedule encoder inputs. encoder_inputs_to_schedule = None - new_encoder_budget = encoder_budget + new_encoder_compute_budget = encoder_compute_budget if request.has_encoder_inputs: (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget) = self._try_schedule_encoder_inputs( + new_encoder_compute_budget + ) = self._try_schedule_encoder_inputs( request, request.num_computed_tokens, num_new_tokens, - encoder_budget) + encoder_compute_budget) if num_new_tokens == 0: # The request cannot be scheduled because one of the following @@ -262,6 +257,7 @@ class Scheduler(SchedulerInterface): preempted_req = self.running.pop() self.kv_cache_manager.free(preempted_req) + self.encoder_cache_manager.free(preempted_req) preempted_req.status = RequestStatus.PREEMPTED preempted_req.num_computed_tokens = 0 if self.log_stats: @@ -284,14 +280,7 @@ class Scheduler(SchedulerInterface): # Schedule the request. scheduled_running_reqs.append(request) - if request.use_structured_output: - # PERF: in case of chunked prefill, - # request might not include any new tokens. 
- # Therefore, we might introduce some additional - # cycle to fill in the bitmask, which could be a big no-op. - structured_output_request_ids[request.request_id] = req_index - req_to_new_block_ids[request.request_id] = ( - new_blocks.get_block_ids()) + req_to_new_blocks[request.request_id] = new_blocks num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens req_index += 1 @@ -314,7 +303,7 @@ class Scheduler(SchedulerInterface): # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) - encoder_budget = new_encoder_budget + encoder_compute_budget = new_encoder_compute_budget # Record the LoRAs in scheduled_running_reqs scheduled_loras: set[int] = set() @@ -398,7 +387,7 @@ class Scheduler(SchedulerInterface): num_computed_tokens = request.num_computed_tokens encoder_inputs_to_schedule = None - new_encoder_budget = encoder_budget + new_encoder_compute_budget = encoder_compute_budget # KVTransfer: loading remote KV, do not allocate for new work. if load_kv_async: @@ -429,10 +418,10 @@ class Scheduler(SchedulerInterface): # Schedule encoder inputs. if request.has_encoder_inputs: (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget + new_encoder_compute_budget ) = self._try_schedule_encoder_inputs( request, num_computed_tokens, num_new_tokens, - encoder_budget) + encoder_compute_budget) if num_new_tokens == 0: # The request cannot be scheduled. break @@ -446,6 +435,22 @@ class Scheduler(SchedulerInterface): == 0 else self.num_lookahead_tokens) + # Determine if we need to allocate cross-attention blocks. + if self.is_encoder_decoder and request.has_encoder_inputs: + # TODO(russellb): For Whisper, we know that the input is + # always padded to the maximum length. If we support other + # encoder-decoder models, this will need to be updated if we + # want to only allocate what is needed. 
+ assert ("whisper" + in self.vllm_config.model_config.model.lower()), ( + "Whisper is the only supported " + "encoder-decoder model.") + num_encoder_tokens = MULTIMODAL_REGISTRY.\ + get_encdec_max_encoder_len( + self.vllm_config.model_config) + else: + num_encoder_tokens = 0 + new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens + num_external_computed_tokens, @@ -453,6 +458,7 @@ class Scheduler(SchedulerInterface): new_computed_blocks, num_lookahead_tokens=effective_lookahead_tokens, delay_cache_blocks=load_kv_async, + num_encoder_tokens=num_encoder_tokens, ) if new_blocks is None: @@ -480,9 +486,6 @@ class Scheduler(SchedulerInterface): request.status = RequestStatus.WAITING_FOR_REMOTE_KVS continue - if request.use_structured_output: - structured_output_request_ids[request.request_id] = ( - req_index) req_index += 1 self.running.append(request) if self.log_stats: @@ -498,8 +501,8 @@ class Scheduler(SchedulerInterface): if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - req_to_new_block_ids[request.request_id] = ( - self.kv_cache_manager.get_block_ids(request.request_id)) + req_to_new_blocks[request.request_id] = ( + self.kv_cache_manager.get_blocks(request.request_id)) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -514,7 +517,7 @@ class Scheduler(SchedulerInterface): # Allocate the encoder cache. 
for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) - encoder_budget = new_encoder_budget + encoder_compute_budget = new_encoder_compute_budget # Put back any skipped requests at the head of the waiting queue if skipped_waiting_requests: @@ -541,15 +544,10 @@ class Scheduler(SchedulerInterface): self.kv_cache_manager.get_num_common_prefix_blocks( any_request, len(self.running))) - grammar_bitmask = self.structured_output_manager.grammar_bitmask( - self.requests, - structured_output_request_ids, - scheduled_spec_decode_tokens, - ) # Construct the scheduler output. new_reqs_data = [ - NewRequestData.from_request(req, - req_to_new_block_ids[req.request_id]) + NewRequestData.from_request( + req, req_to_new_blocks[req.request_id].get_block_ids()) for req in scheduled_new_reqs ] cached_reqs_data = self._make_cached_request_data( @@ -557,8 +555,11 @@ class Scheduler(SchedulerInterface): scheduled_resumed_reqs, num_scheduled_tokens, scheduled_spec_decode_tokens, - req_to_new_block_ids, + req_to_new_blocks, ) + structured_output_request_ids, grammar_bitmask = ( + self.get_grammar_bitmask(self.running, + scheduled_spec_decode_tokens)) scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -572,7 +573,8 @@ class Scheduler(SchedulerInterface): # It contains the request IDs that are finished in between # the previous and the current steps. finished_req_ids=self.finished_req_ids, - free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), + free_encoder_mm_hashes=self.encoder_cache_manager. 
+ get_freed_mm_hashes(), structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, ) @@ -630,11 +632,11 @@ class Scheduler(SchedulerInterface): resumed_reqs: list[Request], num_scheduled_tokens: dict[str, int], spec_decode_tokens: dict[str, list[int]], - req_to_new_block_ids: dict[str, tuple[list[int], ...]], + req_to_new_blocks: dict[str, KVCacheBlocks], ) -> CachedRequestData: req_ids: list[str] = [] new_token_ids: list[list[int]] = [] - new_block_ids: list[tuple[list[int], ...]] = [] + new_block_ids: list[Optional[tuple[list[int], ...]]] = [] num_computed_tokens: list[int] = [] use_connector = self.connector is not None @@ -657,7 +659,8 @@ class Scheduler(SchedulerInterface): # out of bounds errors. TODO: Remove this once the KVConnector # is updated to handle token IDs properly. new_token_ids.append([]) - new_block_ids.append(req_to_new_block_ids[req_id]) + new_block_ids.append( + req_to_new_blocks[req_id].get_block_ids(allow_none=True)) num_computed_tokens.append(req.num_computed_tokens) # Because resumed_reqs is usually empty, it is more efficient to do # in-place appending so that we don't need to allocate a new list. @@ -677,7 +680,7 @@ class Scheduler(SchedulerInterface): request: Request, num_computed_tokens: int, num_new_tokens: int, - encoder_budget: int, + encoder_compute_budget: int, ) -> tuple[list[int], int, int]: """ Determine which encoder inputs need to be scheduled in the current step, @@ -699,11 +702,17 @@ class Scheduler(SchedulerInterface): blocks and externally cached blocks (via KVConnector). 
""" if num_new_tokens == 0 or not request.has_encoder_inputs: - return [], num_new_tokens, encoder_budget + return [], num_new_tokens, encoder_compute_budget encoder_inputs_to_schedule: list[int] = [] mm_positions = request.mm_positions assert mm_positions is not None assert len(mm_positions) > 0 + + # NOTE: since scheduler operates on the request level (possibly with + # multiple encoder inputs per request), we need to create temporary + # trackers for accounting at the encoder input level. + mm_hashes_to_schedule = set() + num_tokens_to_schedule = 0 for i, pos_info in enumerate(mm_positions): start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -714,13 +723,34 @@ class Scheduler(SchedulerInterface): if start_pos >= num_computed_tokens + num_new_tokens: # The encoder input is not needed in this step. break - if start_pos + num_encoder_tokens <= num_computed_tokens: + + if self.is_encoder_decoder and num_computed_tokens > 0: + assert start_pos == 0, ( + "Encoder input should be processed at the beginning of " + "the sequence when encoder-decoder models are used.") + # Encoder input has already been computed + # The calculation here is a bit different. We don't turn encoder + # output into tokens that get processed by the decoder and + # reflected in num_computed_tokens. Instead, start_pos reflects + # the position where we need to ensure we calculate encoder + # inputs. This should always be 0 to ensure we calculate encoder + # inputs before running the decoder. Once we've calculated some + # decoder tokens (num_computed_tokens > 0), then we know we + # already calculated encoder inputs and can skip here. + continue + elif start_pos + num_encoder_tokens <= num_computed_tokens: # The encoder input is already computed and stored # in the decoder's KV cache. continue - if self.encoder_cache_manager.has_cache(request, i): - # The encoder input is already computed and cached. + # The same encoder input has already been scheduled in the current + # step. 
+ if request.mm_hashes[i] in mm_hashes_to_schedule: + continue + + if self.encoder_cache_manager.check_and_update_cache(request, i): + # The encoder input is already computed and cached from a + # previous step. continue # If no encoder input chunking is allowed, we do not want to @@ -733,8 +763,9 @@ class Scheduler(SchedulerInterface): num_new_tokens = start_pos - num_computed_tokens break - if (not self.encoder_cache_manager.can_allocate(request, i) - or num_encoder_tokens > encoder_budget): + if not self.encoder_cache_manager.can_allocate( + request, i, encoder_compute_budget, + num_tokens_to_schedule): # The encoder cache is full or the encoder budget is exhausted. # NOTE(woosuk): We assume that the encoder input tokens should # be processed altogether, as the encoder usually uses @@ -751,9 +782,46 @@ class Scheduler(SchedulerInterface): num_new_tokens = 0 break - encoder_budget -= num_encoder_tokens + num_tokens_to_schedule += num_encoder_tokens + encoder_compute_budget -= num_encoder_tokens + mm_hashes_to_schedule.add(request.mm_hashes[i]) encoder_inputs_to_schedule.append(i) - return encoder_inputs_to_schedule, num_new_tokens, encoder_budget + + return ( + encoder_inputs_to_schedule, + num_new_tokens, + encoder_compute_budget, + ) + + def get_grammar_bitmask( + self, + requests: list[Request], + scheduled_spec_decode_tokens: dict[str, list[int]], + ): + # NOTE: structured_output_request_ids maps + # a request's (request that uses structured output) + # request_id to its index in the batch. + # This will helps us determine to slice the grammar bitmask + # and only applies valid mask for requests that + # uses structured decoding. + structured_output_request_ids: dict[str, int] = {} + for i, req in enumerate(requests): + if req.use_structured_output: + # PERF: in case of chunked prefill, + # request might not include any new tokens. + # Therefore, we might introduce some additional + # cycle to fill in the bitmask, which could be a big no-op. 
+ structured_output_request_ids[req.request_id] = i + + if not structured_output_request_ids: + bitmask = None + else: + bitmask = self.structured_output_manager.grammar_bitmask( + self.requests, + structured_output_request_ids, + scheduled_spec_decode_tokens, + ) + return structured_output_request_ids, bitmask def update_from_output( self, @@ -761,7 +829,6 @@ class Scheduler(SchedulerInterface): model_runner_output: ModelRunnerOutput, ) -> dict[int, EngineCoreOutputs]: sampled_token_ids = model_runner_output.sampled_token_ids - spec_token_ids = model_runner_output.spec_token_ids logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens @@ -846,20 +913,9 @@ class Scheduler(SchedulerInterface): request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] req_id, new_token_ids) - # spec_token_ids comes from the model runner output if num_nans_in_logits is not None and req_id in num_nans_in_logits: request.num_nans_in_logits = num_nans_in_logits[req_id] - # Add newly generated spec token ids to the request. - if spec_token_ids is not None: - if self.structured_output_manager.should_advance(request): - metadata = request.structured_output_request - # Needs to happen after new_token_ids are accepted. - request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] - spec_token_ids[req_index]) - else: - request.spec_token_ids = spec_token_ids[req_index] - # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) if new_token_ids or pooler_output is not None \ @@ -886,9 +942,7 @@ class Scheduler(SchedulerInterface): # Remove the stopped requests from the running and waiting queues. 
if stopped_running_reqs: - self.running = [ - req for req in self.running if req not in stopped_running_reqs - ] + self.running = remove_all(self.running, stopped_running_reqs) if stopped_preempted_reqs: # This is a rare case and unlikely to impact performance. self.waiting.remove_requests(stopped_preempted_reqs) @@ -918,10 +972,13 @@ class Scheduler(SchedulerInterface): finished_requests=finished_set) finished_req_ids.clear() - if engine_core_outputs: + if (stats := self.make_stats(spec_decoding_stats)) is not None: # Return stats to only one of the front-ends. - next(iter(engine_core_outputs.values())).scheduler_stats = ( - self.make_stats(spec_decoding_stats)) + if (eco := next(iter(engine_core_outputs.values()), None)) is None: + # We must return the stats even if there are no request + # outputs this step. + engine_core_outputs[0] = eco = EngineCoreOutputs() + eco.scheduler_stats = stats return engine_core_outputs @@ -964,6 +1021,30 @@ class Scheduler(SchedulerInterface): self.encoder_cache_manager.free_encoder_input( request, input_id) + def update_draft_token_ids( + self, + draft_token_ids: DraftTokenIds, + ) -> None: + for req_id, spec_token_ids in zip( + draft_token_ids.req_ids, + draft_token_ids.draft_token_ids, + ): + request = self.requests.get(req_id) + if request is None or request.is_finished(): + # The request may have been finished. Skip. + continue + + # Add newly generated spec token ids to the request. + if not spec_token_ids: + # NOTE(woosuk): request.spec_token_ids should be updated. 
+ request.spec_token_ids.clear() + elif self.structured_output_manager.should_advance(request): + metadata = request.structured_output_request + request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] + spec_token_ids) + else: + request.spec_token_ids = spec_token_ids + def get_request_counts(self) -> tuple[int, int]: """Returns (num_running_reqs, num_waiting_reqs).""" return len(self.running), len(self.waiting) @@ -990,7 +1071,7 @@ class Scheduler(SchedulerInterface): else: request_ids = set(request_ids) - running_requests_to_remove = [] + running_requests_to_remove = set() waiting_requests_to_remove = [] valid_requests = [] @@ -1003,13 +1084,13 @@ class Scheduler(SchedulerInterface): valid_requests.append(request) if request.status == RequestStatus.RUNNING: - running_requests_to_remove.append(request) + running_requests_to_remove.add(request) else: waiting_requests_to_remove.append(request) # Remove all requests from queues at once for better efficiency - for request in running_requests_to_remove: - self.running.remove(request) + if running_requests_to_remove: + self.running = remove_all(self.running, running_requests_to_remove) if waiting_requests_to_remove: self.waiting.remove_requests(waiting_requests_to_remove) @@ -1036,7 +1117,6 @@ class Scheduler(SchedulerInterface): def _free_blocks(self, request: Request): assert request.is_finished() self.kv_cache_manager.free(request) - self.kv_cache_manager.free_block_hashes(request) del self.requests[request.request_id] def get_num_unfinished_requests(self) -> int: diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py index 42ec95091f962..42d3e5c68b4c8 100644 --- a/vllm/v1/core/sched/utils.py +++ b/vllm/v1/core/sched/utils.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib from typing import Optional import torch @@ -7,6 +8,38 @@ import torch from vllm.v1.request import Request, 
RequestStatus +def remove_all(lst: list, items_to_remove: set) -> list: + """Remove all items from a list that are in the items_to_remove set. + + This method optimizes for the common case of removing a single item, + falling back to list comprehension for multiple items. + + Args: + lst: The list to remove items from + items_to_remove: Set of items to remove + + Returns: + Either the modified original list (for single item removal) or + a new list (for multiple item removal). Callers should use the + returned value. + + Note: + For single item removal, this modifies the original list in-place + and returns it. For multiple items, it creates and returns a new list. + """ + if not items_to_remove: + return lst + + if len(items_to_remove) == 1: + # Fast path for single item removal (most common case) + item = next(iter(items_to_remove)) + with contextlib.suppress(ValueError): + lst.remove(item) + return lst + # For multiple items, use list comprehension + return [item for item in lst if item not in items_to_remove] + + def check_stop(request: Request, max_model_len: int, pooler_output: Optional[torch.Tensor] = None) -> bool: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 8f310023a8cd3..f0af92122958c 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -3,14 +3,14 @@ import itertools from abc import ABC, abstractmethod from collections import defaultdict -from typing import Callable from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - FullAttentionSpec, KVCacheSpec, - MambaSpec, SlidingWindowSpec) + CrossAttentionSpec, FullAttentionSpec, + KVCacheSpec, MambaSpec, + SlidingWindowSpec) from vllm.v1.request import Request @@ -25,7 +25,6 @@ class SingleTypeKVCacheManager(ABC): kv_cache_spec: 
KVCacheSpec, block_pool: BlockPool, kv_cache_group_id: int, - caching_hash_fn: Callable, ) -> None: """ Initializes the SingleTypeKVCacheManager. @@ -33,7 +32,6 @@ class SingleTypeKVCacheManager(ABC): kv_cache_spec: The kv_cache_spec for this manager. block_pool: The block pool. kv_cache_group_id: The id of the kv cache group of this manager. - caching_hash_fn: The caching hash function. """ self.block_size = kv_cache_spec.block_size @@ -52,7 +50,6 @@ class SingleTypeKVCacheManager(ABC): # data for reempted ones. self.num_cached_block: dict[str, int] = {} - self.caching_hash_fn = caching_hash_fn self.kv_cache_group_id = kv_cache_group_id self._null_block = block_pool.null_block @@ -130,14 +127,12 @@ class SingleTypeKVCacheManager(ABC): req_blocks.extend(new_blocks) return new_blocks - def cache_blocks(self, request: Request, block_hashes: list[BlockHash], - num_tokens: int) -> None: + def cache_blocks(self, request: Request, num_tokens: int) -> None: """ Cache the blocks for the request. Args: request: The request. - block_hashes: The block hashes of the request. num_tokens: The total number of tokens that need to be cached (including tokens that are already cached). 
""" @@ -147,12 +142,10 @@ class SingleTypeKVCacheManager(ABC): self.block_pool.cache_full_blocks( request=request, blocks=self.req_to_blocks[request.request_id], - block_hashes=block_hashes, num_cached_blocks=num_cached_blocks, num_full_blocks=num_full_blocks, block_size=self.block_size, kv_cache_group_id=self.kv_cache_group_id, - hash_fn=self.caching_hash_fn, ) self.num_cached_block[request.request_id] = num_full_blocks @@ -560,11 +553,62 @@ class MambaManager(SingleTypeKVCacheManager): return new_blocks +class CrossAttentionManager(SingleTypeKVCacheManager): + """Manager for cross-attention KV cache in encoder-decoder models.""" + + def save_new_computed_blocks( + self, request_id: str, + new_computed_blocks: list[KVCacheBlock]) -> None: + # We do not cache blocks for cross-attention to be shared between + # requests, so `new_computed_blocks` should always be empty. + assert len(new_computed_blocks) == 0 + + def cache_blocks(self, request: Request, num_tokens: int) -> None: + # We do not cache blocks for cross-attention to be shared between + # requests, so this method is not relevant. + raise ValueError("Should not be called as prefix caching is disabled.") + + def get_num_common_prefix_blocks(self, request_id: str, + num_running_requests: int) -> int: + # Cross-attention blocks contain request-specific encoder states + # and are not shared between different requests + return 0 + + @classmethod + def find_longest_cache_hit( + cls, + block_hashes: list[BlockHash], + max_length: int, + kv_cache_group_ids: list[int], + block_pool: BlockPool, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, + ) -> tuple[list[KVCacheBlock], ...]: + assert isinstance(kv_cache_spec, CrossAttentionSpec), ( + "CrossAttentionManager can only be used for cross-attention groups" + ) + # Cross-attention does not benefit from prefix caching since: + # 1. Encoder states are unique per request (different audio/image + # inputs) + # 2. 
Encoder states are computed once per request, not incrementally + # 3. No reusable prefix exists between different multimodal inputs + # Return empty blocks to indicate no cache hits + raise NotImplementedError( + "CrossAttentionManager does not support caching") + + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: + # Cross-attention blocks represent encoder states which are needed + # for the entire decoding process, so no blocks should be skipped + pass + + spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, MambaSpec: MambaManager, + CrossAttentionSpec: CrossAttentionManager, } diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index b29394f3e6760..f7ec982db41b4 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -3,6 +3,7 @@ import enum import time +from collections.abc import Sequence from typing import Any, Optional, Union import msgspec @@ -47,7 +48,7 @@ class EngineCoreRequest( request_id: str prompt_token_ids: list[int] - mm_kwargs: Optional[list[MultiModalKwargsItem]] + mm_kwargs: Optional[Sequence[Optional[MultiModalKwargsItem]]] mm_hashes: Optional[list[str]] mm_placeholders: Optional[list[PlaceholderRange]] sampling_params: Optional[SamplingParams] diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index edc2e235c3c3f..dbea0b610b31a 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1,12 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import os +import socket import time -from collections.abc import AsyncGenerator, Mapping +from collections.abc import AsyncGenerator, Iterable, Mapping from copy import copy from typing import Any, Optional, Union import numpy as np +import torch 
import vllm.envs as envs from vllm.config import ModelConfig, VllmConfig @@ -27,7 +30,8 @@ from vllm.transformers_utils.config import ( from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext -from vllm.utils import Device, cancel_task_threadsafe, cdiv, deprecate_kwargs +from vllm.utils import (Device, as_list, cancel_task_threadsafe, cdiv, + deprecate_kwargs) from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError @@ -143,6 +147,26 @@ class AsyncLLM(EngineClient): except RuntimeError: pass + if envs.VLLM_TORCH_PROFILER_DIR: + logger.info( + "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s", # noqa: E501 + envs.VLLM_TORCH_PROFILER_DIR) + worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm" + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + ], + with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + envs.VLLM_TORCH_PROFILER_DIR, + worker_name=worker_name, + use_gzip=True)) + else: + logger.info( + "Torch profiler disabled. AsyncLLM CPU traces will not be collected." 
# noqa: E501 + ) + self.profiler = None + @classmethod @deprecate_kwargs( "disable_log_requests", @@ -431,14 +455,16 @@ class AsyncLLM(EngineClient): self.output_handler = asyncio.create_task(output_handler()) - async def abort(self, request_id: str) -> None: + async def abort(self, request_id: Union[str, Iterable[str]]) -> None: """Abort RequestId in OutputProcessor and EngineCore.""" - request_ids = self.output_processor.abort_requests((request_id, )) - await self.engine_core.abort_requests_async(request_ids) + request_ids = (request_id, ) if isinstance( + request_id, str) else as_list(request_id) + all_request_ids = self.output_processor.abort_requests(request_ids) + await self.engine_core.abort_requests_async(all_request_ids) if self.log_requests: - logger.info("Aborted request %s.", request_id) + logger.info("Aborted request(s) %s.", ",".join(request_ids)) async def encode( self, @@ -559,14 +585,19 @@ class AsyncLLM(EngineClient): raise self.dead_error async def start_profile(self) -> None: - await self.engine_core.profile_async(True) + coros = [self.engine_core.profile_async(True)] + if self.profiler is not None: + coros.append(asyncio.to_thread(self.profiler.start)) + await asyncio.gather(*coros) async def stop_profile(self) -> None: - await self.engine_core.profile_async(False) + coros = [self.engine_core.profile_async(False)] + if self.profiler is not None: + coros.append(asyncio.to_thread(self.profiler.stop)) + await asyncio.gather(*coros) async def reset_mm_cache(self) -> None: - self.processor.mm_registry.reset_processor_cache(self.model_config) - self.processor.mm_input_cache_client.reset() + self.processor.clear_cache() await self.engine_core.reset_mm_cache_async() async def reset_prefix_cache(self, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index ed426f8ff452b..a7038e2d2c264 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -22,12 +22,15 @@ from vllm.logger import init_logger from vllm.logging_utils.dump_input 
import dump_engine_exception from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import receiver_cache_from_config from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.utils import (decorate_logs, make_zmq_socket, +from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket, resolve_obj_by_qualname, set_process_title) -from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, +from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_config, + get_request_block_hasher, + init_none_hash, unify_kv_cache_configs) from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput @@ -36,8 +39,8 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, ReconfigureDistributedRequest, ReconfigureRankType, UtilityOutput, UtilityResult) -from vllm.v1.engine.mm_input_cache import MultiModalInputCacheServer -from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses +from vllm.v1.engine.utils import (EngineHandshakeMetadata, EngineZmqAddresses, + get_device_indices) from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats @@ -124,9 +127,11 @@ class EngineCore: > 1, log_stats=self.log_stats, ) + self.use_spec_decode = vllm_config.speculative_config is not None - self.mm_input_cache_server = MultiModalInputCacheServer( - vllm_config.model_config, MULTIMODAL_REGISTRY) + self.mm_registry = mm_registry = MULTIMODAL_REGISTRY + self.mm_receiver_cache = receiver_cache_from_config( + vllm_config, mm_registry) # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. 
This enables us to asynchronously @@ -140,6 +145,19 @@ class EngineCore: self.batch_queue_size) self.batch_queue = queue.Queue(self.batch_queue_size) + self.request_block_hasher: Optional[Callable[[Request], + list[BlockHash]]] = None + if (self.vllm_config.cache_config.enable_prefix_caching + or self.scheduler.get_kv_connector() is not None): + + block_size = vllm_config.cache_config.block_size + caching_hash_fn = get_hash_fn_by_name( + vllm_config.cache_config.prefix_caching_hash_algo) + init_none_hash(caching_hash_fn) + + self.request_block_hasher = get_request_block_hasher( + block_size, caching_hash_fn) + def _initialize_kv_caches( self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: start = time.time() @@ -279,6 +297,13 @@ class EngineCore: return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0) + def post_step(self, model_executed: bool) -> None: + if self.use_spec_decode and model_executed: + # Take the draft token ids. + draft_token_ids = self.model_executor.take_draft_token_ids() + if draft_token_ids is not None: + self.scheduler.update_draft_token_ids(draft_token_ids) + def step_with_batch_queue( self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]: """Schedule and execute batches with the batch queue. @@ -347,7 +372,8 @@ class EngineCore: logger.warning("Resetting the multi-modal cache when requests are " "in progress may lead to desynced internal caches.") - self.mm_input_cache_server.reset() + if self.mm_receiver_cache is not None: + self.mm_receiver_cache.clear_cache() def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() @@ -412,12 +438,14 @@ class EngineCore: assert request.mm_kwargs is not None # Note on thread safety: no race condition. - # `mm_input_cache_server` is reset at the end of LLMEngine init, + # `mm_receiver_cache` is reset at the end of LLMEngine init, # and will only accessed in the input processing thread afterwards. 
- request.mm_kwargs = self.mm_input_cache_server.get_and_update( - request.mm_kwargs, request.mm_hashes) + if self.mm_receiver_cache is not None: + request.mm_kwargs = self.mm_receiver_cache.get_and_update( + request.mm_kwargs, request.mm_hashes) - req = Request.from_engine_core_request(request) + req = Request.from_engine_core_request(request, + self.request_block_hasher) if req.use_structured_output: # Note on thread safety: no race condition. # `grammar_init` is only invoked in input processing thread. For @@ -730,6 +758,8 @@ class EngineCoreProc(EngineCore): # Put EngineCoreOutputs into the output queue. for output in (outputs.items() if outputs else ()): self.output_queue.put_nowait(output) + # Post-step hook. + self.post_step(model_executed) return model_executed @@ -1140,22 +1170,30 @@ class DPEngineCoreActor(DPEngineCoreProc): # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501 # and get_accelerator_ids_for_accelerator_resource() in worker.py # of ray. - self._set_cuda_visible_devices(vllm_config, local_dp_rank) + self._set_visible_devices(vllm_config, local_dp_rank) super().__init__(vllm_config, local_client, "", executor_class, log_stats) - def _set_cuda_visible_devices(self, vllm_config: VllmConfig, - local_dp_rank: int): + def _set_visible_devices(self, vllm_config: VllmConfig, + local_dp_rank: int): from vllm.platforms import current_platform - device_control_env_var = current_platform.device_control_env_var + if current_platform.is_xpu(): + pass + else: + device_control_env_var = current_platform.device_control_env_var + self._set_cuda_visible_devices(vllm_config, local_dp_rank, + device_control_env_var) + + def _set_cuda_visible_devices(self, vllm_config: VllmConfig, + local_dp_rank: int, + device_control_env_var: str): world_size = vllm_config.parallel_config.world_size # Set CUDA_VISIBLE_DEVICES or equivalent. 
try: - os.environ[device_control_env_var] = ",".join( - str(current_platform.device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * - world_size, (local_dp_rank + 1) * world_size)) + value = get_device_indices(device_control_env_var, local_dp_rank, + world_size) + os.environ[device_control_env_var] = value except IndexError as e: raise Exception( f"Error setting {device_control_env_var}: " diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 29ee0a9dfb1e2..65f7abc97110c 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -574,13 +574,22 @@ class MPClient(EngineCoreClient): def _process_utility_output(output: UtilityOutput, utility_results: dict[int, AnyFuture]): - """Set the result from a utility method in the waiting future""" + """Set the result from a utility method in the waiting future.""" future = utility_results.pop(output.call_id) - if output.failure_message is not None: - future.set_exception(Exception(output.failure_message)) - else: - assert output.result is not None - future.set_result(output.result.result) + failure_message = output.failure_message + try: + if failure_message is not None: + future.set_exception(Exception(failure_message)) + else: + assert output.result is not None + future.set_result(output.result.result) + except asyncio.InvalidStateError: + # This can happen if the future is cancelled due to the + # original calling task being cancelled. 
+ if failure_message is not None: + logger.error( + "Cancelled call to utility method failed " + "with error: %s", failure_message) class SyncMPClient(MPClient): @@ -1181,21 +1190,6 @@ class DPLBAsyncMPClient(DPAsyncMPClient): await self._send_input(EngineCoreRequestType.ABORT, request_ids, engine) - async def _send_reconfig_message( - self, reconfig_request: ReconfigureDistributedRequest, - engine: EngineIdentity) -> asyncio.Future: - """Send reconfiguration message and return the result future without - waiting for completion.""" - call_id = uuid.uuid1().int >> 64 - future = asyncio.get_running_loop().create_future() - self.utility_results[call_id] = future - message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( - (self.client_index, call_id, "reinitialize_distributed", - (reconfig_request, )))) - await self._send_input_message(message, engine, reconfig_request) - self._ensure_output_queue_task() - return future - async def scale_elastic_ep(self, new_data_parallel_size: int) -> None: """Scale elastic EP data parallel size""" cur_data_parallel_size = len(self.core_engines) @@ -1205,7 +1199,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient): f"different from cur_data_parallel_size {cur_data_parallel_size}") assert self.vllm_config.parallel_config.data_parallel_backend == \ - "ray", ("Only ray DP backend supports scaling elastic EP") + "ray", "Only ray DP backend supports scaling elastic EP" scale_up = new_data_parallel_size > cur_data_parallel_size @@ -1237,9 +1231,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient): data_parallel_master_ip, new_data_parallel_master_port=self.vllm_config.parallel_config. 
data_parallel_master_port) - future = await self._send_reconfig_message(reconfig_request, - engine) - reconfig_futures.append(future) + coro = self._call_utility_async("reinitialize_distributed", + reconfig_request, + engine=engine) + reconfig_futures.append(asyncio.create_task(coro)) logger.info("All reconfigure messages sent, starting engine creation") @@ -1309,9 +1304,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient): if cur_dp_rank >= new_data_parallel_size: reconfig_request.new_data_parallel_rank = \ ReconfigureRankType.SHUTDOWN_CURRENT_RANK - future = await self._send_reconfig_message(reconfig_request, - engine) - reconfig_futures.append(future) + coro = self._call_utility_async("reinitialize_distributed", + reconfig_request, + engine=engine) + reconfig_futures.append(asyncio.create_task(coro)) for _ in range(new_data_parallel_size, cur_data_parallel_size): self.core_engines.pop() diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 2f5504ea14b41..04ad51aae0a8c 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -74,6 +74,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): params = request.sampling_params assert params is not None self.stop = stop = params.stop + self.min_tokens = params.min_tokens self.include_stop_str_in_output = params.include_stop_str_in_output # Number of chars to hold back when stop strings are to be excluded @@ -111,10 +112,14 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): # 1) Detokenize the new token ids incrementally. # TODO(woosuk): This method becomes very inefficient when the number of # new_token_ids is more than 1. We need to optimize this. 
- offset_before = len(self.output_text) + stop_check_offset = len(self.output_text) for new_token_id in new_token_ids: self.token_ids.append(new_token_id) self.output_text += self.decode_next(new_token_id) + # Support min_tokens, see https://github.com/vllm-project/vllm/pull/22014 + if self.min_tokens and len( + self.output_token_ids) <= self.min_tokens: + stop_check_offset = len(self.output_text) if stop_terminated: if skipped_stop_token_id is not None: @@ -125,10 +130,10 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): # 2) Evaluate stop strings. stop_string = None - if self.stop: + if self.stop and len(self.output_token_ids) > self.min_tokens: stop = StopChecker.check_stop_strings( output_text=self.output_text, - new_char_count=len(self.output_text) - offset_before, + new_char_count=len(self.output_text) - stop_check_offset, stop=self.stop, include_in_output=self.include_stop_str_in_output, ) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 5a00a930951cc..7130f666ef19f 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -271,8 +271,7 @@ class LLMEngine: self.engine_core.profile(False) def reset_mm_cache(self): - self.processor.mm_registry.reset_processor_cache(self.model_config) - self.processor.mm_input_cache_client.reset() + self.processor.clear_cache() self.engine_core.reset_mm_cache() def reset_prefix_cache(self, device: Optional[Device] = None): diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py deleted file mode 100644 index 1fed74330f0ec..0000000000000 --- a/vllm/v1/engine/mm_input_cache.py +++ /dev/null @@ -1,118 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Mapping -from typing import TYPE_CHECKING - -from vllm.multimodal import MultiModalRegistry -from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata -from 
vllm.multimodal.inputs import MultiModalKwargsItem, NestedTensors - -if TYPE_CHECKING: - from vllm.config import ModelConfig - -# The idea of multimodal input caching is based on having a client and -# a server, where the client executes in the frontend process (=P0) and the -# server in the core process (=P1). -# -# -- P0: -# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of -# each input multi-modal item (e.g. image), -# - BaseMultiModalProcessor processes the input items into `mm_kwargs`, -# which are MultiModalKwargsItem instances that each correspond to an -# input multi-modal item. -# - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding -# `mm_hash` for each item. It stores the `mm_hash` as keys and the size -# of `mm_kwargs`, but not the `mm_kwargs` themselves, to avoid taking -# up additional memory in P0. -# - The `mm_hash` is always sent to P1. -# - The corresponding `mm_kwargs` are only sent to P1 if they are not cached -# in MultiModalInputCacheServer. -# -# -- P1: -# - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0), -# MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`. -# - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0), -# MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`. -# - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to -# the engine for model execution. -# -# Both Client and Server must perform cache update and eviction based on the -# same item size. This ensures that the keys of MultiModalInputCacheClient -# and MultiModalInputCacheServer are mirrored, allowing us to determine in P0 -# whether a key is cached in MultiModalInputCacheServer by querying -# MultiModalInputCacheClient without having to communicate with P1. 
- - -class MultiModalInputCacheClient: - """Used by P0 to check whether multi-modal kwargs are cached in P1.""" - - def __init__(self, model_config: "ModelConfig", - mm_registry: MultiModalRegistry) -> None: - super().__init__() - - self.enabled = mm_registry.enable_mm_input_cache(model_config) - self.mm_cache = MultiModalCache.get_lru_cache( - model_config.get_mm_input_cache_gb(), - MultiModalCacheItemMetadata, - ) - - def get_and_update( - self, - mm_kwargs: list[MultiModalKwargsItem], - mm_hashes: list[str], - ) -> list[MultiModalKwargsItem]: - if not self.enabled: - return mm_kwargs - - assert len(mm_kwargs) == len(mm_hashes) - - out_mm_items = list[MultiModalKwargsItem]() - for mm_item, mm_hash in zip(mm_kwargs, mm_hashes): - if self.mm_cache.get(mm_hash) is not None: - out_mm_items.append(mm_item.without_data()) - else: - self.mm_cache[mm_hash] = \ - MultiModalCacheItemMetadata.wraps(mm_item.require_data()) - out_mm_items.append(mm_item) - - return out_mm_items - - def reset(self) -> None: - self.mm_cache.clear() - - -class MultiModalInputCacheServer: - """Used by P1 to avoid requiring past multi-modal kwargs from P0.""" - - def __init__(self, model_config: "ModelConfig", - mm_registry: MultiModalRegistry) -> None: - super().__init__() - - self.enabled = mm_registry.enable_mm_input_cache(model_config) - self.mm_cache = MultiModalCache.get_lru_cache( - model_config.get_mm_input_cache_gb(), - Mapping[str, NestedTensors], - ) - - def get_and_update( - self, - mm_kwargs: list[MultiModalKwargsItem], - mm_hashes: list[str], - ) -> list[MultiModalKwargsItem]: - if not self.enabled: - return mm_kwargs - - assert len(mm_kwargs) == len(mm_hashes) - - out_mm_items = list[MultiModalKwargsItem]() - for mm_item, mm_hash in zip(mm_kwargs, mm_hashes): - if (mm_data := mm_item.get_data()) is None: - out_mm_items.append(mm_item.with_data(self.mm_cache[mm_hash])) - else: - self.mm_cache[mm_hash] = mm_data - out_mm_items.append(mm_item) - - return out_mm_items - - def 
reset(self) -> None: - self.mm_cache.clear() diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 376c76a7e7285..7ed60156626bf 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -11,6 +11,7 @@ from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.cache import processor_cache_from_config from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.utils import argsort_mm_positions @@ -18,9 +19,10 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) +from vllm.v1.structured_output.backend_lm_format_enforcer import ( + validate_structured_output_request_lm_format_enforcer) from vllm.v1.structured_output.backend_outlines import ( validate_structured_output_request_outlines) from vllm.v1.structured_output.backend_xgrammar import ( @@ -45,16 +47,17 @@ class Processor: self.generation_config_fields = ( self.model_config.try_get_generation_config()) - self.input_preprocessor = InputPreprocessor(self.model_config, - self.tokenizer, - mm_registry) - self.mm_input_cache_client = MultiModalInputCacheClient( - self.model_config, mm_registry) + self.mm_registry = mm_registry + self.mm_processor_cache = processor_cache_from_config( + vllm_config, mm_registry) - @property - def mm_registry(self): - return self.input_preprocessor.mm_registry + self.input_preprocessor = InputPreprocessor( + self.model_config, + self.tokenizer, + mm_registry, + 
mm_processor_cache=self.mm_processor_cache, + ) def _validate_logprobs( self, @@ -200,6 +203,9 @@ class Processor: elif engine_level_backend == "outlines": # outlines backend validate_structured_output_request_outlines(params) + elif engine_level_backend == "lm-format-enforcer": + # lm format enforcer backend + validate_structured_output_request_lm_format_enforcer(params) else: # NOTE: engine_level_backend must be "auto" here, because we have # checked supported_backends above. @@ -252,13 +258,10 @@ class Processor: # 1. Tokenize text prompt, with LoRA request if one exists. # 2. For multimodal models with a merged preprocessor, preprocess # multimodal data and expand prompt token ids accordingly. - return_mm_hashes = (self.model_config.processor_return_mm_hashes - or bool(self.cache_config.enable_prefix_caching)) processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) from vllm.platforms import current_platform current_platform.validate_request( @@ -295,13 +298,13 @@ class Processor: pooling_params = params.clone() # Multimodal related. 
- sorted_mm_inputs: Optional[list[MultiModalKwargsItem]] = None + sorted_mm_inputs: Optional[list[Optional[MultiModalKwargsItem]]] = None sorted_mm_positions: Optional[list[PlaceholderRange]] = None sorted_mm_hashes: Optional[list[str]] = None if decoder_inputs["type"] == "multimodal": decoder_mm_inputs = decoder_inputs["mm_kwargs"] decoder_mm_positions = decoder_inputs["mm_placeholders"] - decoder_mm_hashes = decoder_inputs.get("mm_hashes") + decoder_mm_hashes = decoder_inputs["mm_hashes"] # Merge and flatten multimodal placeholders, hashes and inputs # from dictionaries to lists, and sort them by each item's position @@ -309,24 +312,18 @@ class Processor: sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions) sorted_mm_inputs = [ - decoder_mm_inputs.get_item(modality, idx) + decoder_mm_inputs[modality][idx] for modality, idx in sorted_mm_idxs ] sorted_mm_positions = [ decoder_mm_positions[modality][idx] for modality, idx in sorted_mm_idxs ] - sorted_mm_hashes = None if decoder_mm_hashes is None else [ + sorted_mm_hashes = [ decoder_mm_hashes[modality][idx] for modality, idx in sorted_mm_idxs ] - if sorted_mm_hashes is not None: - sorted_mm_inputs = self.mm_input_cache_client.get_and_update( - sorted_mm_inputs, - sorted_mm_hashes, - ) - return decoder_inputs.get("prompt"), EngineCoreRequest( request_id=request_id, prompt_token_ids=decoder_inputs["prompt_token_ids"], @@ -393,7 +390,7 @@ class Processor: assert isinstance(mm_processor, EncDecMultiModalProcessor) if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper + return # Skip encoder length check for Whisper and Donut if model_config.is_multimodal_model: suggestion = ( @@ -414,3 +411,6 @@ class Processor: # TODO: Find out how many placeholder tokens are there so we can # check that chunked prefill does not truncate them # max_batch_len = self.scheduler_config.max_num_batched_tokens + + def clear_cache(self) -> None: + self.input_preprocessor.clear_cache() diff --git 
a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 770aa7d9dcc8a..56ef8477d267a 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -71,7 +71,7 @@ class EngineHandshakeMetadata: connect to. """ addresses: EngineZmqAddresses - parallel_config: dict[str, Union[int, str]] + parallel_config: dict[str, Union[int, str, list[int]]] class CoreEngineProcManager: @@ -164,19 +164,33 @@ def set_device_control_env_var(vllm_config: VllmConfig, """ world_size = vllm_config.parallel_config.world_size evar = current_platform.device_control_env_var + + value = get_device_indices(evar, local_dp_rank, world_size) + with patch.dict(os.environ, values=((evar, value), )): + yield + + +def get_device_indices(device_control_env_var: str, local_dp_rank: int, + world_size: int): + """ + Returns a comma-separated string of device indices for the specified + data parallel rank. + + For example, if world_size=2 and local_dp_rank=1, and there are 4 devices, + this will select devices 2 and 3 for local_dp_rank=1. + """ try: value = ",".join( str(current_platform.device_id_to_physical_device_id(i)) for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * world_size)) except IndexError as e: - raise Exception(f"Error setting {evar}: " + raise Exception(f"Error setting {device_control_env_var}: " f"local range: [{local_dp_rank * world_size}, " f"{(local_dp_rank + 1) * world_size}) " "base value: " - f"\"{os.getenv(evar)}\"") from e - with patch.dict(os.environ, values=((evar, value), )): - yield + f"\"{os.getenv(device_control_env_var)}\"") from e + return value class CoreEngineActorManager: @@ -254,6 +268,19 @@ class CoreEngineActorManager: dp_vllm_config = copy.deepcopy(vllm_config) dp_vllm_config.parallel_config.placement_group = pg local_client = index < local_engine_count + + # Ray XPU known issue: dpctl initializes the GPU runtime early, so + # setting device env vars in Ray actor's initialization method + # will not affect device selection. 
See: + # https://github.com/ray-project/ray/blob/master/python/ray/_private/accelerators/intel_gpu.py#L56 # noqa: E501 + if current_platform.is_xpu(): + device_evar = current_platform.device_control_env_var + device_indices = get_device_indices(device_evar, local_index, + world_size) + actor_env_vars = self.env_vars_dict.copy() + actor_env_vars[device_evar] = device_indices + runtime_env = RuntimeEnv(env_vars=actor_env_vars) + actor = ray.remote(DPEngineCoreActor).options( scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=pg, @@ -798,6 +825,8 @@ def wait_for_engine_startup( parallel_config.data_parallel_master_ip, "data_parallel_master_port": parallel_config.data_parallel_master_port, + "_data_parallel_master_port_list": + parallel_config._data_parallel_master_port_list, "data_parallel_size": parallel_config.data_parallel_size, })) diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 50b9634a49e1b..4be2f74177b1f 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from concurrent.futures import Future -from typing import Callable, Union +from typing import Callable, Optional, Union import torch import torch.distributed as dist @@ -13,8 +13,9 @@ from vllm.executor.uniproc_executor import ( # noqa ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0) from vllm.executor.uniproc_executor import ( # noqa UniProcExecutor as UniProcExecutorV0) +from vllm.utils import resolve_obj_by_qualname from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput FailureCallback = Callable[[], None] @@ -50,6 +51,13 @@ class Executor(ExecutorBase): # TODO: make v1 scheduling deterministic # to support external launcher executor_class = ExecutorWithExternalLauncher + elif isinstance(distributed_executor_backend, 
str): + executor_class = resolve_obj_by_qualname( + distributed_executor_backend) + if not issubclass(executor_class, ExecutorBase): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"ExecutorBase. Got {executor_class}.") else: raise ValueError("Unknown distributed executor backend: " f"{distributed_executor_backend}") @@ -88,6 +96,10 @@ class Executor(ExecutorBase): args=(scheduler_output, )) return output[0] + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + output = self.collective_rpc("take_draft_token_ids") + return output[0] + @property def max_concurrent_batches(self) -> int: return 1 diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 0db3bcd7fb408..15b88a2128994 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -33,7 +33,7 @@ from vllm.utils import (decorate_logs, get_distributed_init_method, get_loopback_ip, get_mp_context, get_open_port, set_process_title) from vllm.v1.executor.abstract import Executor, FailureCallback -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -191,6 +191,12 @@ class MultiprocExecutor(Executor): outputs, self.output_rank) return self.kv_output_aggregator.aggregate(outputs, self.output_rank) + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + # OPTIMIZATION: Get output only from a single worker (output_rank) + outputs = self.collective_rpc("take_draft_token_ids", + unique_reply_rank=self.output_rank) + return outputs[0] + def collective_rpc(self, method: Union[str, Callable], timeout: Optional[float] = None, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 429416afa2483..a3e4d393e4d20 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -11,6 +11,7 @@ from typing_extensions import 
Self from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.utils import cdiv, get_dtype_size logger = init_logger(__name__) @@ -203,6 +204,28 @@ class MambaSpec(KVCacheSpec): return self.page_size_bytes +@dataclass(frozen=True) +class EncoderOnlyAttentionSpec(AttentionSpec): + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + # Encoder-only layers do not need KV cache + return 0 + + +@dataclass(frozen=True) +class CrossAttentionSpec(AttentionSpec): + """ + KV cache spec for cross-attention layers in encoder-decoder models. + """ + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + # For cross-attention, we need to cache encoder states + # Get encoder length (e.g., 1500 for Whisper). + max_encoder_len = MULTIMODAL_REGISTRY.\ + get_encdec_max_encoder_len(vllm_config.model_config) + return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes + + @dataclass class KVCacheTensor: """ diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 7d7cd0c94dd04..f8d6b24702f3c 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -94,9 +94,6 @@ class ModelRunnerOutput: # each request due to speculative/jump decoding. 
sampled_token_ids: list[list[int]] - # num_reqs x num_spec_tokens - spec_token_ids: Optional[list[list[int]]] - # [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1] # [num_reqs] @@ -117,10 +114,18 @@ class ModelRunnerOutput: num_nans_in_logits: Optional[dict[str, int]] = None +@dataclass +class DraftTokenIds: + + # [num_reqs] + req_ids: list[str] + # num_reqs x num_draft_tokens + draft_token_ids: list[list[int]] + + EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], req_id_to_index={}, sampled_token_ids=[], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index 28af720d05fd1..46506d272e90a 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -6,15 +6,40 @@ from typing import Optional import torch from vllm.pooling_params import PoolingParams +from vllm.utils import is_pin_memory_available + +pin_memory = is_pin_memory_available() + + +@dataclass +class PoolingCursor: + index: list[int] + first_token_indices_gpu: torch.Tensor + last_token_indices_gpu: torch.Tensor + prompt_lens_cpu: torch.Tensor + num_scheduled_tokens_cpu: torch.Tensor + + def __getitem__(self, indices: slice): + return PoolingCursor( + index=self.index[indices], + first_token_indices_gpu=self.first_token_indices_gpu[indices], + last_token_indices_gpu=self.last_token_indices_gpu[indices], + prompt_lens_cpu=self.prompt_lens_cpu[indices], + num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices], + ) + + def is_partial_prefill(self): + return not torch.all( + self.prompt_lens_cpu == self.num_scheduled_tokens_cpu) @dataclass class PoolingMetadata: """Tensors for pooling.""" - - prompt_lens: torch.Tensor + prompt_lens: torch.Tensor # CPU Tensor prompt_token_ids: Optional[torch.Tensor] pooling_params: list[PoolingParams] + pooling_cursor: Optional[PoolingCursor] = None def __getitem__(self, indices: slice): return PoolingMetadata( @@ -22,4 +47,31 @@ class 
PoolingMetadata: prompt_token_ids=None if self.prompt_token_ids is None else self.prompt_token_ids[indices], pooling_params=self.pooling_params[indices], + pooling_cursor=None + if self.pooling_cursor is None else self.pooling_cursor[indices], ) + + def build_pooling_cursor(self, num_scheduled_tokens: list[int], + device: torch.device): + self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens, + self.prompt_lens, device) + + +def build_pooling_cursor(num_scheduled_tokens: list[int], + prompt_lens: torch.Tensor, device: torch.device): + assert len(prompt_lens) == len(num_scheduled_tokens) + + n_seq = len(num_scheduled_tokens) + index = list(range(n_seq)) + num_scheduled_tokens = torch.tensor(num_scheduled_tokens, device="cpu") + cumsum = torch.zeros(n_seq + 1, + dtype=torch.int64, + pin_memory=pin_memory, + device="cpu") + torch.cumsum(num_scheduled_tokens, dim=0, out=cumsum[1:]) + cumsum = cumsum.to(device, non_blocking=True) + return PoolingCursor(index=index, + first_token_indices_gpu=cumsum[:n_seq], + last_token_indices_gpu=cumsum[1:] - 1, + prompt_lens_cpu=prompt_lens, + num_scheduled_tokens_cpu=num_scheduled_tokens) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index d1f1c7f98755f..4e99a9ccef46e 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -3,7 +3,8 @@ import enum import time -from typing import TYPE_CHECKING, Any, Optional, Union +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, Optional, Union from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.pooling_params import PoolingParams @@ -16,6 +17,7 @@ from vllm.v1.utils import ConstantList if TYPE_CHECKING: from vllm.lora.request import LoRARequest + from vllm.v1.core.kv_cache_utils import BlockHash class Request: @@ -36,6 +38,8 @@ class Request: structured_output_request: Optional["StructuredOutputRequest"] = None, cache_salt: Optional[str] = None, priority: int = 0, + block_hasher: Optional[Callable[["Request"], + 
list["BlockHash"]]] = None, ) -> None: self.request_id = request_id self.client_index = client_index @@ -50,8 +54,7 @@ class Request: time.time() self.status = RequestStatus.WAITING - if sampling_params and sampling_params.guided_decoding is not None: - self.status = RequestStatus.WAITING_FOR_FSM + self.use_structured_output = False self.events: list[EngineCoreEvent] = [] self.stop_reason: Union[int, str, None] = None @@ -59,12 +62,15 @@ class Request: self.kv_transfer_params: Optional[dict[str, Any]] = None if pooling_params is not None: + # Pooling models. self.max_tokens = 1 elif sampling_params is not None: + # Generative models. assert sampling_params.max_tokens is not None self.max_tokens = sampling_params.max_tokens if sampling_params.guided_decoding is not None: self.status = RequestStatus.WAITING_FOR_FSM + self.use_structured_output = True if sampling_params.extra_args is not None: self.kv_transfer_params = \ @@ -108,17 +114,30 @@ class Request: # indicates that the output is corrupted self.num_nans_in_logits = 0 + self.block_hashes: list[BlockHash] = [] + self.get_hash_new_full_blocks: Optional[Callable[ + [], list[BlockHash]]] = None + if block_hasher is not None: + self.get_hash_new_full_blocks = partial(block_hasher, self) + self.block_hashes = self.get_hash_new_full_blocks() + @classmethod - def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": + def from_engine_core_request( + cls, request: EngineCoreRequest, + block_hasher: Optional[Callable[["Request"], list["BlockHash"]]] + ) -> "Request": if request.mm_kwargs is not None: - assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), ( + mm_kwargs_lst = list(request.mm_kwargs) + assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem), ( "mm_kwargs was not updated in EngineCore.add_request") + else: + mm_kwargs_lst = None return cls( request_id=request.request_id, client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, - 
multi_modal_kwargs=request.mm_kwargs, + multi_modal_kwargs=mm_kwargs_lst, multi_modal_hashes=request.mm_hashes, multi_modal_placeholders=request.mm_placeholders, sampling_params=request.sampling_params, @@ -131,6 +150,7 @@ class Request: if request.sampling_params else None, cache_salt=request.cache_salt, priority=request.priority, + block_hasher=block_hasher, ) def append_output_token_ids( @@ -144,6 +164,9 @@ class Request: self._output_token_ids.extend(token_ids) self._all_token_ids.extend(token_ids) + if self.get_hash_new_full_blocks is not None: + self.block_hashes.extend(self.get_hash_new_full_blocks()) + @property def is_output_corrupted(self) -> bool: return self.num_nans_in_logits > 0 @@ -171,11 +194,6 @@ class Request: num_tokens = self.mm_positions[input_id].length return num_tokens - @property - def use_structured_output(self) -> bool: - return self.sampling_params is not None and \ - self.sampling_params.guided_decoding is not None - def record_event( self, event_type: EngineCoreEventType, diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py new file mode 100644 index 0000000000000..8220269162951 --- /dev/null +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib +import itertools +from collections.abc import Sequence +from typing import TYPE_CHECKING, Optional, Union + +import torch + +from vllm.logger import init_logger +from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor, + MinPLogitsProcessor, + MinTokensLogitsProcessor) +from vllm.v1.sample.logits_processor.interface import (BatchUpdate, + LogitsProcessor, + MoveDirectionality) +from vllm.v1.sample.logits_processor.state import (BatchUpdateBuilder, + LogitsProcessors) + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + +# Error message when the 
user tries to initialize vLLM with a pooling model +# and custom logitsprocs +STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom" + " logits processors.") + +LOGITSPROCS_GROUP = 'vllm.logits_processors' + +BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [ + MinTokensLogitsProcessor, + LogitBiasLogitsProcessor, + MinPLogitsProcessor, +] + + +def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]: + """Load all installed logit processor plugins""" + + import sys + + if sys.version_info < (3, 10): + from importlib_metadata import entry_points + else: + from importlib.metadata import entry_points + + installed_logitsprocs_plugins = entry_points(group=LOGITSPROCS_GROUP) + if len(installed_logitsprocs_plugins) == 0: + logger.debug("No logitsprocs plugins installed (group %s).", + LOGITSPROCS_GROUP) + return [] + + # Load logitsprocs plugins + logger.debug("Loading installed logitsprocs plugins (group %s):", + LOGITSPROCS_GROUP) + classes: list[type[LogitsProcessor]] = [] + for entrypoint in installed_logitsprocs_plugins: + try: + logger.debug("- Loading logitproc plugin entrypoint=%s target=%s", + entrypoint.name, entrypoint.value) + classes.append(entrypoint.load()) + except Exception as e: + raise RuntimeError( + f"Failed to load LogitsProcessor plugin {entrypoint}") from e + return classes + + +def _load_logitsprocs_by_fqcns( + logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]] +) -> list[type[LogitsProcessor]]: + """Load logit processor types, identifying them by fully-qualified class + names (FQCNs). + + Effectively, a mixed list of logitproc types and FQCN strings is converted + into a list of entirely logitproc types, by loading from the FQCNs. + + FQCN syntax is <module>:<type> i.e.
x.y.z:CustomLogitProc + + Already-loaded logitproc types must be subclasses of LogitsProcessor + + Args: + logits_processors: Potentially mixed list of logitsprocs types and FQCN + strings for logitproc types + + Returns: + List of logitproc types + + """ + if not logits_processors: + return [] + + logger.debug( + "%s additional custom logits processors specified, checking whether " + "they need to be loaded.", len(logits_processors)) + + classes: list[type[LogitsProcessor]] = [] + for ldx, logitproc in enumerate(logits_processors): + if isinstance(logitproc, type): + logger.debug(" - Already-loaded logit processor: %s", + logitproc.__name__) + if not issubclass(logitproc, LogitsProcessor): + raise ValueError( + f"{logitproc.__name__} is not a subclass of LogitsProcessor" + ) + classes.append(logitproc) + continue + + logger.debug("- Loading logits processor %s", logitproc) + module_path, qualname = logitproc.split(":") + + try: + # Load module + module = importlib.import_module(module_path) + except Exception as e: + raise RuntimeError( + f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}" + ) from e + + # Walk down dotted name to get logitproc class + obj = module + for attr in qualname.split("."): + obj = getattr(obj, attr) + if not isinstance(obj, type): + raise ValueError("Loaded logit processor must be a type.") + if not issubclass(obj, LogitsProcessor): + raise ValueError( + f"{obj.__name__} must be a subclass of LogitsProcessor") + classes.append(obj) + + return classes + + +def _load_custom_logitsprocs( + logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]], +) -> list[type[LogitsProcessor]]: + """Load all custom logits processors. 
+ + * First load all installed logitproc plugins + * Second load custom logitsprocs passed by the user at initialization time + + Args: + logits_processors: potentially mixed list of logitproc types and + logitproc type fully-qualified names (FQCNs) + which need to be loaded + + Returns: + A list of all loaded logitproc types + """ + from vllm.platforms import current_platform + if current_platform.is_tpu(): + # No logitsprocs specified by caller + # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs + return [] + + return (_load_logitsprocs_plugins() + + _load_logitsprocs_by_fqcns(logits_processors)) + + +def build_logitsprocs( + vllm_config: "VllmConfig", + device: torch.device, + is_pin_memory: bool, + is_pooling_model: bool, + custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (), +) -> LogitsProcessors: + if is_pooling_model: + if custom_logitsprocs: + raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS) + logger.debug("Skipping logits processor loading because pooling models" + " do not support logits processors.") + return LogitsProcessors() + custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs) + return LogitsProcessors( + ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain( + BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes)) + + +__all__ = [ + "LogitsProcessor", "LogitBiasLogitsProcessor", "MinPLogitsProcessor", + "MinTokensLogitsProcessor", "BatchUpdate", "BatchUpdateBuilder", + "MoveDirectionality", "LogitsProcessors", "build_logitsprocs", + "STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP" +] diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor/builtin.py similarity index 51% rename from vllm/v1/sample/logits_processor.py rename to vllm/v1/sample/logits_processor/builtin.py index 3a06e71057cdd..00dd757489ca0 100644 --- a/vllm/v1/sample/logits_processor.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -1,241 +1,32 @@ # SPDX-License-Identifier:
Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import dataclasses -from abc import ABC, abstractmethod -from collections.abc import Iterator, Sequence -from dataclasses import dataclass, field -from enum import Enum -from itertools import chain -from typing import Optional, Union +from collections.abc import Sequence +from typing import TYPE_CHECKING, Optional import torch -from torch._prims_common import DeviceLikeType -from vllm import PoolingParams, SamplingParams -from vllm.logger import init_logger +from vllm.v1.sample.logits_processor.interface import (BatchUpdate, + LogitsProcessor, + MoveDirectionality) -logger = init_logger(__name__) - - -class MoveDirectionality(Enum): - # One-way i1->i2 req move within batch - UNIDIRECTIONAL = 0 - # Two-way i1<->i2 req swap within batch - SWAP = 1 - - -# (index, params, output_tok_ids) tuples for new -# requests added to the batch. -AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]] -# (index 1, index 2, directionality) tuples representing -# one-way moves or two-way swaps of requests in batch -MovedRequest = tuple[int, int, MoveDirectionality] -# Batch indices of any removed requests. -RemovedRequest = int - - -@dataclasses.dataclass(frozen=True) -class BatchUpdate: - """Persistent batch state change info for logitsprocs""" - batch_size: int # Current num reqs in batch - - # Metadata for requests added to, removed from, and moved - # within the persistent batch. 
- # - # Note: each added request is represented as - # (index, params, output_tok_ids) - # Key assumption: output_tok_ids is a reference to the - # request's running output tokens list; in this way - # the logits processors always see the latest list of - # generated tokens - removed: Sequence[RemovedRequest] - moved: Sequence[MovedRequest] - added: Sequence[AddedRequest] - - -class BatchUpdateBuilder: - """Helps track persistent batch state changes and build - a batch update data structure for logitsprocs - - Assumptions: - * All information about requests removed from persistent batch - during a step is aggregated in self._removed through calls to - self.removed_append() at the beginning of a step. This must happen - before the first time that self.removed, self.pop_removed() - or self.peek_removed() are invoked in a given step - * After the first time that self.removed, self.pop_removed() - or self.peek_removed() are read in a step, no new removals - are registered using self.removed_append() - * Elements of self._removed are never directly modified, added or - removed (i.e. modification is only via self.removed_append() and - self.pop_removed()) - - Guarantees under above assumptions: - * self.removed is always sorted in descending order - * self.pop_removed() and self.peek_removed() both return - the lowest removed request index in the current step - """ - - _removed: list[RemovedRequest] - _is_removed_sorted: bool - moved: list[MovedRequest] - added: list[AddedRequest] - - def __init__( - self, - removed: Optional[list[RemovedRequest]] = None, - moved: Optional[list[MovedRequest]] = None, - added: Optional[list[AddedRequest]] = None, - ) -> None: - self._removed = removed or [] - self.moved = moved or [] - self.added = added or [] - self._is_removed_sorted = False - - def _ensure_removed_sorted(self) -> None: - """Sort removed request indices in - descending order. - - Idempotent after first call in a - given step, until reset. 
- """ - if not self._is_removed_sorted: - self._removed.sort(reverse=True) - self._is_removed_sorted = True - - @property - def removed(self) -> list[RemovedRequest]: - """Removed request indices sorted in - descending order""" - self._ensure_removed_sorted() - return self._removed - - def removed_append(self, index: int) -> None: - """Register the removal of a request from - the persistent batch. - - Must not be called after the first time - self.removed, self.pop_removed() or - self.peek_removed() are invoked. - - Args: - index: request index - """ - if self._is_removed_sorted: - raise RuntimeError("Cannot register new removed request after" - " self.removed has been read.") - self._removed.append(index) - - def has_removed(self) -> bool: - return bool(self._removed) - - def peek_removed(self) -> Optional[int]: - """Return lowest removed request index""" - if self.has_removed(): - self._ensure_removed_sorted() - return self._removed[-1] - return None - - def pop_removed(self) -> Optional[int]: - """Pop lowest removed request index""" - if self.has_removed(): - self._ensure_removed_sorted() - return self._removed.pop() - return None - - def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]: - """Generate a logitsprocs batch update data structure - and reset internal batch update builder state. 
- - Args: - batch_size: current persistent batch size - - Returns: - Frozen logitsprocs batch update instance; `None` if no updates - """ - # Reset removal-sorting logic - self._is_removed_sorted = False - if not any((self._removed, self.moved, self.added)): - # No update; short-circuit - return None - # Build batch state update - batch_update = BatchUpdate( - batch_size=batch_size, - removed=self._removed, - moved=self.moved, - added=self.added, - ) - # Reset removed/moved/added update lists - self._removed = [] - self.moved = [] - self.added = [] - return batch_update - - -class LogitsProcessor(ABC): - - @abstractmethod - def apply(self, logits: torch.Tensor) -> torch.Tensor: - raise NotImplementedError - - @abstractmethod - def is_argmax_invariant(self) -> bool: - """True if logits processor has no impact on the - argmax computation in greedy sampling. - NOTE: may or may not have the same value for all - instances of a given LogitsProcessor subclass, - depending on subclass implementation. - TODO(andy): won't be utilized until logits - processors are user-extensible - """ - raise NotImplementedError - - @abstractmethod - def update_state( - self, - batch_update: Optional[BatchUpdate], - ) -> None: - """Called when there are new output tokens, prior - to each forward pass. - - Args: - batch_update is non-None iff there have been - changes to the batch makeup. 
- """ - raise NotImplementedError - - -@dataclass -class LogitsProcessorManager: - """Encapsulates initialized logitsproc objects.""" - argmax_invariant: list[LogitsProcessor] = field( - default_factory=list) # argmax-invariant logitsprocs - non_argmax_invariant: list[LogitsProcessor] = field( - default_factory=list) # non-argmax-invariant logitsprocs - - @property - def all(self) -> Iterator[LogitsProcessor]: - """Iterator over all logits processors.""" - return chain(self.argmax_invariant, self.non_argmax_invariant) - - -###### ----- Built-in LogitsProcessor impls below here +if TYPE_CHECKING: + from vllm.config import VllmConfig class MinPLogitsProcessor(LogitsProcessor): - def __init__(self, max_num_reqs: int, pin_memory: bool, - device: DeviceLikeType): - super().__init__() + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): + max_num_reqs = vllm_config.scheduler_config.max_num_seqs self.min_p_count: int = 0 self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ), dtype=torch.float32, device="cpu", - pin_memory=pin_memory) + pin_memory=is_pin_memory) self.min_p_cpu = self.min_p_cpu_tensor.numpy() - self.use_double_tensor = torch.device("cpu") != torch.device(device) + self.use_double_tensor = torch.device(device).type != "cpu" if self.use_double_tensor: # Pre-allocated device tensor @@ -260,31 +51,39 @@ class MinPLogitsProcessor(LogitsProcessor): needs_update = False # Process added requests. 
- for index, params, _ in batch_update.added: - min_p = params.min_p if isinstance(params, SamplingParams) else 0.0 - if self.min_p_cpu[index] != min_p: + for index, params, _, _ in batch_update.added: + min_p = params.min_p + min_p_before = self.min_p_cpu[index] + if min_p_before != min_p: needs_update = True self.min_p_cpu[index] = min_p - if min_p: - self.min_p_count += 1 + if min_p and not min_p_before: + self.min_p_count += 1 + elif not min_p and min_p_before: + self.min_p_count -= 1 if self.min_p_count: # Process removed requests. - needs_update |= bool(batch_update.removed) - for index in batch_update.removed: - if self.min_p_cpu[index]: - self.min_p_count -= 1 + if batch_update.removed: + needs_update = True + for index in batch_update.removed: + if self.min_p_cpu[index]: + self.min_p_cpu[index] = 0 + self.min_p_count -= 1 - # Process moved requests, unidirectional (a->b) and swap (a<->b) + # Process moved requests, unidirectional (a->b) and swap (a<->b). for adx, bdx, direct in batch_update.moved: - change = (min_p_a := - self.min_p_cpu[adx]) != (min_p_b := - self.min_p_cpu[bdx]) - needs_update |= change - if change: + min_p_a, min_p_b = self.min_p_cpu[adx], self.min_p_cpu[bdx] + if min_p_a != min_p_b: + needs_update = True self.min_p_cpu[bdx] = min_p_a if direct == MoveDirectionality.SWAP: self.min_p_cpu[adx] = min_p_b + if direct == MoveDirectionality.UNIDIRECTIONAL: + if min_p_a: + self.min_p_cpu[adx] = 0 + if min_p_b: + self.min_p_count -= 1 # Update tensors if needed. 
size = batch_update.batch_size @@ -316,11 +115,10 @@ class MinPLogitsProcessor(LogitsProcessor): class LogitBiasLogitsProcessor(LogitsProcessor): - def __init__(self, pin_memory: bool, device: torch.device): - super().__init__() - self.biases: dict[int, dict[int, float]] = {} + def __init__(self, _, device: torch.device, is_pin_memory: bool): self.device = device - self.pin_memory = pin_memory + self.pin_memory = is_pin_memory + self.biases: dict[int, dict[int, float]] = {} self.bias_tensor: torch.Tensor = torch.tensor(()) self.logits_slice = (self._device_tensor([], torch.int32), @@ -337,9 +135,8 @@ class LogitBiasLogitsProcessor(LogitsProcessor): needs_update: bool = False # Process added requests. - for index, params, _ in batch_update.added: - if isinstance(params, SamplingParams) and (lb := - params.logit_bias): + for index, params, _, _ in batch_update.added: + if lb := params.logit_bias: self.biases[index] = lb needs_update = True else: @@ -400,12 +197,12 @@ class LogitBiasLogitsProcessor(LogitsProcessor): class MinTokensLogitsProcessor(LogitsProcessor): - def __init__(self, pin_memory: bool, device: torch.device): + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): # index -> (min_toks, output_token_ids, stop_token_ids) - super().__init__() - self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {} self.device = device - self.pin_memory = pin_memory + self.pin_memory = is_pin_memory + self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {} # (req_idx_tensor,eos_tok_id_tensor) self.logits_slice: tuple[torch.Tensor, @@ -424,9 +221,8 @@ class MinTokensLogitsProcessor(LogitsProcessor): if batch_update: # Process added requests. 
- for index, params, output_tok_ids in batch_update.added: - if (isinstance(params, SamplingParams) - and (min_tokens := params.min_tokens) + for index, params, _, output_tok_ids in batch_update.added: + if ((min_tokens := params.min_tokens) and len(output_tok_ids) < min_tokens): # Replace request metadata at batch index self.min_toks[index] = (min_tokens, output_tok_ids, @@ -499,35 +295,3 @@ class MinTokensLogitsProcessor(LogitsProcessor): # Inhibit EOS token for requests which have not reached min length logits[self.logits_slice] = -float("inf") return logits - - -def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int, - device: torch.device) -> LogitsProcessorManager: - """Construct 'builtin' vLLM logitsprocs which the engine - loads by default. - - Args: - pin_memory_available: pinned memory is available for use - for use by logitsproc - max_num_reqs: ceiling on request count in persistent batch - device: inference device - - Returns: - Data structure encapsulating loaded logitsprocs - """ - min_tokens_logitproc = MinTokensLogitsProcessor( - pin_memory=pin_memory_available, device=device) - logit_bias_logitproc = LogitBiasLogitsProcessor( - pin_memory=pin_memory_available, device=device) - min_p_logitproc = MinPLogitsProcessor( - pin_memory=pin_memory_available, - device=device, - # +1 for temporary swap space - max_num_reqs=max_num_reqs + 1) - return LogitsProcessorManager( - non_argmax_invariant=[ - min_tokens_logitproc, - logit_bias_logitproc, - ], - argmax_invariant=[min_p_logitproc], - ) diff --git a/vllm/v1/sample/logits_processor/interface.py b/vllm/v1/sample/logits_processor/interface.py new file mode 100644 index 0000000000000..12b4db24bff88 --- /dev/null +++ b/vllm/v1/sample/logits_processor/interface.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from collections.abc import Sequence +from dataclasses import 
dataclass +from enum import Enum, auto +from typing import TYPE_CHECKING, Optional + +import torch + +from vllm import SamplingParams + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + +class MoveDirectionality(Enum): + # One-way i1->i2 req move within batch + UNIDIRECTIONAL = auto() + # Two-way i1<->i2 req swap within batch + SWAP = auto() + + +# (index, params, prompt_tok_ids, output_tok_ids) tuples for new +# requests added to the batch. +AddedRequest = tuple[int, SamplingParams, list[int], list[int]] + +# (index 1, index 2, directionality) tuples representing +# one-way moves or two-way swaps of requests in batch +MovedRequest = tuple[int, int, MoveDirectionality] + +# Batch indices of any removed requests. +RemovedRequest = int + + +@dataclass(frozen=True) +class BatchUpdate: + """Persistent batch state change info for logitsprocs""" + batch_size: int # Current num reqs in batch + + # Metadata for requests added to, removed from, and moved + # within the persistent batch. + # + # Key assumption: the `output_tok_ids` list (which is an element of each + # tuple in `added`) is a reference to the request's running output tokens + # list; via this reference, the logits processors always see the latest + # list of generated output tokens + removed: Sequence[RemovedRequest] + moved: Sequence[MovedRequest] + added: Sequence[AddedRequest] + + +class LogitsProcessor(ABC): + + @abstractmethod + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool) -> None: + raise NotImplementedError + + @abstractmethod + def apply(self, logits: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def is_argmax_invariant(self) -> bool: + """True if logits processor has no impact on the + argmax computation in greedy sampling. + NOTE: may or may not have the same value for all + instances of a given LogitsProcessor subclass, + depending on subclass implementation. 
+ """ + raise NotImplementedError + + @abstractmethod + def update_state( + self, + batch_update: Optional["BatchUpdate"], + ) -> None: + """Called when there are new output tokens, prior + to each forward pass. + + Args: + batch_update is non-None iff there have been + changes to the batch makeup. + """ + raise NotImplementedError diff --git a/vllm/v1/sample/logits_processor/state.py b/vllm/v1/sample/logits_processor/state.py new file mode 100644 index 0000000000000..31cece58c7db5 --- /dev/null +++ b/vllm/v1/sample/logits_processor/state.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterator +from itertools import chain +from typing import TYPE_CHECKING, Optional + +from vllm.v1.sample.logits_processor.interface import (AddedRequest, + BatchUpdate, + MovedRequest, + RemovedRequest) + +if TYPE_CHECKING: + from vllm.v1.sample.logits_processor.interface import LogitsProcessor + + +class BatchUpdateBuilder: + """Helps track persistent batch state changes and build + a batch update data structure for logitsprocs + Assumptions: + * All information about requests removed from persistent batch + during a step is aggregated in self._removed through calls to + self.removed_append() at the beginning of a step. This must happen + before the first time that self.removed, self.pop_removed() + or self.peek_removed() are invoked in a given step + * After the first time that self.removed, self.pop_removed() + or self.peek_removed() are read in a step, no new removals + are registered using self.removed_append() + * Elements of self._removed are never directly modified, added or + removed (i.e. 
modification is only via self.removed_append() and + self.pop_removed()) + Guarantees under above assumptions: + * self.removed is always sorted in descending order + * self.pop_removed() and self.peek_removed() both return + the lowest removed request index in the current step + """ + + _removed: list[RemovedRequest] + _is_removed_sorted: bool + moved: list[MovedRequest] + added: list[AddedRequest] + + def __init__( + self, + removed: Optional[list[RemovedRequest]] = None, + moved: Optional[list[MovedRequest]] = None, + added: Optional[list[AddedRequest]] = None, + ) -> None: + self._removed = removed or [] + self.moved = moved or [] + self.added = added or [] + self._is_removed_sorted = False + + # Used to track changes in the pooling case + # where we don't populate the added list. + self.batch_changed = False + + def _ensure_removed_sorted(self) -> None: + """Sort removed request indices in + descending order. + Idempotent after first call in a + given step, until reset. + """ + if not self._is_removed_sorted: + self._removed.sort(reverse=True) + self._is_removed_sorted = True + + @property + def removed(self) -> list[RemovedRequest]: + """Removed request indices sorted in + descending order""" + self._ensure_removed_sorted() + return self._removed + + def removed_append(self, index: int) -> None: + """Register the removal of a request from the persistent batch. + + Must not be called after the first time self.removed, + self.pop_removed() or self.peek_removed() are invoked. 
+ + Args: + index: request index + """ + if self._is_removed_sorted: + raise RuntimeError("Cannot register new removed request after" + " self.removed has been read.") + self._removed.append(index) + self.batch_changed = True + + def has_removed(self) -> bool: + return bool(self._removed) + + def peek_removed(self) -> Optional[int]: + """Return lowest removed request index""" + if self.has_removed(): + self._ensure_removed_sorted() + return self._removed[-1] + return None + + def pop_removed(self) -> Optional[int]: + """Pop lowest removed request index""" + if self.has_removed(): + self._ensure_removed_sorted() + return self._removed.pop() + return None + + def reset(self) -> bool: + """Returns True if there were any changes to the batch.""" + self._is_removed_sorted = False + self._removed.clear() + self.moved.clear() + self.added.clear() + batch_changed = self.batch_changed + self.batch_changed = False + return batch_changed + + def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]: + """Generate a logitsprocs batch update data structure and reset + internal batch update builder state. 
+ + Args: + batch_size: current persistent batch size + + Returns: + Frozen logitsprocs batch update instance; `None` if no updates + """ + # Reset removal-sorting logic + self._is_removed_sorted = False + self.batch_changed = False + if not any((self._removed, self.moved, self.added)): + # No update; short-circuit + return None + # Build batch state update + batch_update = BatchUpdate( + batch_size=batch_size, + removed=self._removed, + moved=self.moved, + added=self.added, + ) + self._removed = [] + self.moved = [] + self.added = [] + return batch_update + + +class LogitsProcessors: + """Encapsulates initialized logitsproc objects.""" + + def __init__( + self, + logitsprocs: Optional[Iterator["LogitsProcessor"]] = None) -> None: + self.argmax_invariant: list[LogitsProcessor] = [] + self.non_argmax_invariant: list[LogitsProcessor] = [] + if logitsprocs: + for logitproc in logitsprocs: + (self.argmax_invariant if logitproc.is_argmax_invariant() else + self.non_argmax_invariant).append(logitproc) + + @property + def all(self) -> Iterator["LogitsProcessor"]: + """Iterator over all logits processors.""" + return chain(self.argmax_invariant, self.non_argmax_invariant) diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 1189b12f30776..9d6a87cea3d07 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -6,7 +6,7 @@ from typing import Optional import torch -from vllm.v1.sample.logits_processor import LogitsProcessorManager +from vllm.v1.sample.logits_processor import LogitsProcessors @dataclass @@ -40,4 +40,4 @@ class SamplingMetadata: bad_words_token_ids: dict[int, list[list[int]]] # Loaded logits processors - logitsprocs: LogitsProcessorManager + logitsprocs: LogitsProcessors diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index e0434c8f3d713..7bd4a5a380ac0 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -8,6 +8,7 @@ import 
torch.nn as nn from packaging import version from vllm import envs +from vllm.config import LogprobsMode from vllm.logger import init_logger from vllm.platforms import current_platform @@ -28,9 +29,16 @@ class TopKTopPSampler(nn.Module): Implementations may update the logits tensor in-place. """ - def __init__(self): + def __init__( + self, + logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS) -> None: super().__init__() - if current_platform.is_cuda(): + self.logprobs_mode = logprobs_mode + # flashinfer optimization does not apply if intermediate + # logprobs/logits after top_k/top_p need to be returned + if logprobs_mode not in (LogprobsMode.PROCESSED_LOGITS, + LogprobsMode.PROCESSED_LOGPROBS + ) and current_platform.is_cuda(): if is_flashinfer_available: flashinfer_version = flashinfer.__version__ if version.parse(flashinfer_version) < version.parse("0.2.3"): @@ -63,10 +71,12 @@ class TopKTopPSampler(nn.Module): "native implementation of top-p & top-k sampling. For the " "best performance, please install FlashInfer.") self.forward = self.forward_native - elif current_platform.is_tpu(): - self.forward = self.forward_tpu else: self.forward = self.forward_native + if current_platform.is_tpu(): + self.apply_top_k_top_p = apply_top_k_top_p_tpu + else: + self.apply_top_k_top_p = apply_top_k_top_p def forward_native( self, @@ -74,15 +84,20 @@ class TopKTopPSampler(nn.Module): generators: dict[int, torch.Generator], k: Optional[torch.Tensor], p: Optional[torch.Tensor], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ PyTorch-native implementation of top-k and top-p sampling. The logits tensor may be updated in-place. 
""" - logits = apply_top_k_top_p(logits, k, p) + logits = self.apply_top_k_top_p(logits, k, p) + logits_to_return = None + if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS: + logits_to_return = logits + elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS: + logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) + return random_sample(probs, generators), logits_to_return def forward_cuda( self, @@ -90,34 +105,24 @@ class TopKTopPSampler(nn.Module): generators: dict[int, torch.Generator], k: Optional[torch.Tensor], p: Optional[torch.Tensor], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """More optimized implementation for top-k and top-p sampling.""" - if k is None and p is None: - # We prefer `random_sample` over `flashinfer_sample` when sorting is - # not needed. This is because `random_sample` does not require - # CPU-GPU synchronization while `flashinfer_sample` does. - probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) - if generators: - logger.warning_once("FlashInfer 0.2.3+ does not support " - "per-request generators. Falling back to " - "PyTorch-native implementation.") + # We prefer `random_sample` over `flashinfer_sample` when sorting is + # not needed. This is because `random_sample` does not require + # CPU-GPU synchronization while `flashinfer_sample` does. + if (k is None and p is None) or generators: + if generators: + logger.warning_once("FlashInfer 0.2.3+ does not support " + "per-request generators. Falling back to " + "PyTorch-native implementation.") return self.forward_native(logits, generators, k, p) + assert self.logprobs_mode not in ( + LogprobsMode.PROCESSED_LOGITS, LogprobsMode.PROCESSED_LOGPROBS + ), "FlashInfer does not support returning logits/logprobs" # flashinfer sampling functions expect contiguous logits. 
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous # because of slicing operation in logits_processor. - return flashinfer_sample(logits.contiguous(), k, p, generators) - - def forward_tpu( - self, - logits: torch.Tensor, - generators: dict[int, torch.Generator], - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], - ) -> torch.Tensor: - logits = apply_top_k_top_p_tpu(logits, k, p) - probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) + return flashinfer_sample(logits.contiguous(), k, p, generators), None def apply_top_k_top_p_tpu( diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 82f51298f1b59..546531a91610f 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A layer that samples the next tokens from the model's outputs.""" +from typing import Optional + import torch import torch.nn as nn @@ -18,10 +20,50 @@ _SAMPLING_EPS = 1e-5 class Sampler(nn.Module): + """ + A layer that samples the next tokens from the model's outputs + with the following steps in order: - def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"): + 1. If logprobs are requested: + a) If `logprobs_mode` is `raw_logprobs`, compute logprobs + as the final logprobs to return. + b) If `logprobs_mode` is `raw_logits`, clone the logits + as the final logprobs to return. + 2. Convert logits to float32. + 3. Apply allowed token ids whitelist. + 4. Apply bad words exclusion. + 5. Apply logit processors which are not argmax-invariant, + i.e. that can impact greedy sampling. + a) Min tokens processor + b) Logit bias processor + 6. Apply penalties + a) Repetition penalty + b) Frequency penalty + c) Presence penalty + 7. Sample the next tokens. `sample` method performs the following steps: + a) If not `all_random`, perform greedy sampling. 
If `all_greedy`, + return the greedily sampled tokens and final logprobs if requested. + b) Apply temperature. + c) Apply logit processors which are argmax-invariant, by default + the min_p processor. + d) Apply top_k and/or top_p. + e) Sample the next tokens with the probability distribution. + f) If `all_random` or temperature >= epsilon (1e-5), return the + randomly sampled tokens and final logprobs if requested. Else, + return the greedily sampled tokens and logprobs if requested. + 8. Gather the logprobs of the top `max_num_logprobs` and sampled token + (if requested). Note that if the sampled token is within the top + `max_num_logprobs`, the logprob will be eventually merged in + `LogprobsProcessor` during output processing. Therefore, the + final output may contain either `max_num_logprobs + 1` or + `max_num_logprobs` logprobs. + 9. Return the final `SamplerOutput`. + """ + + def __init__(self, + logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS): super().__init__() - self.topk_topp_sampler = TopKTopPSampler() + self.topk_topp_sampler = TopKTopPSampler(logprobs_mode) self.pin_memory = is_pin_memory_available() self.logprobs_mode = logprobs_mode @@ -34,13 +76,11 @@ class Sampler(nn.Module): # temperature scaling) for the top-k logprobs. # This is different from the V0 sampler, which uses the logits that # is used for sampling (after penalties and temperature scaling). - # TODO(rob): provide option for logprobs post sampling. - # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501 num_logprobs = sampling_metadata.max_num_logprobs if num_logprobs is not None: - if self.logprobs_mode == "raw_logprobs": + if self.logprobs_mode == LogprobsMode.RAW_LOGPROBS: raw_logprobs = self.compute_logprobs(logits) - elif self.logprobs_mode == "raw_logits": + elif self.logprobs_mode == LogprobsMode.RAW_LOGITS: raw_logprobs = logits.clone() # Use float32 for the logits. 
@@ -51,21 +91,16 @@ class Sampler(nn.Module): logits = self.apply_bad_words(logits, sampling_metadata) # Apply logits processors which can impact greedy sampling - for processor in (sampling_metadata.logitsprocs.non_argmax_invariant): + for processor in sampling_metadata.logitsprocs.non_argmax_invariant: logits = processor.apply(logits) # Apply penalties (e.g., min_tokens, freq_penalties). logits = self.apply_penalties(logits, sampling_metadata) - # Get the process logprobs or logits. - if num_logprobs is not None: - if self.logprobs_mode == "processed_logprobs": - raw_logprobs = self.compute_logprobs(logits) - elif self.logprobs_mode == "processed_logits": - raw_logprobs = logits.clone() - # Sample the next token. - sampled = self.sample(logits, sampling_metadata) + sampled, processed_logprobs = self.sample(logits, sampling_metadata) + if processed_logprobs is not None: + raw_logprobs = processed_logprobs # Convert sampled token ids to int64 (long) type to ensure compatibility # with subsequent operations that may use these values as indices. # This conversion is necessary because FlashInfer sampling operations @@ -105,7 +140,7 @@ class Sampler(nn.Module): self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Sample logits based on sampling metadata. 
The various logits processing functions called in this method @@ -119,7 +154,13 @@ class Sampler(nn.Module): else: greedy_sampled = self.greedy_sample(logits) if sampling_metadata.all_greedy: - return greedy_sampled + processed_logprobs = None + if sampling_metadata.max_num_logprobs is not None: + if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS: + processed_logprobs = logits + elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS: + processed_logprobs = self.compute_logprobs(logits) + return greedy_sampled, processed_logprobs assert sampling_metadata.temperature is not None @@ -132,7 +173,7 @@ class Sampler(nn.Module): logits = processor.apply(logits) # Apply top_k and/or top_p. - random_sampled = self.topk_topp_sampler( + random_sampled, processed_logprobs = self.topk_topp_sampler( logits, sampling_metadata.generators, sampling_metadata.top_k, @@ -140,7 +181,7 @@ class Sampler(nn.Module): ) if greedy_sampled is None: - return random_sampled + return random_sampled, processed_logprobs sampled = torch.where( sampling_metadata.temperature < _SAMPLING_EPS, @@ -148,7 +189,7 @@ class Sampler(nn.Module): random_sampled, out=greedy_sampled, # Reuse tensor ) - return sampled + return sampled, processed_logprobs def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: return logits.log_softmax(dim=-1, dtype=torch.float32) diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py index 2c9f4892bc247..04545d587e4a9 100644 --- a/vllm/v1/sample/tpu/sampler.py +++ b/vllm/v1/sample/tpu/sampler.py @@ -65,7 +65,7 @@ class Sampler(nn.Module): logits = self.apply_min_p(logits, sampling_metadata.min_p) # Apply top_k and/or top_p. 
- random_sampled = self.topk_topp_sampler( + random_sampled, _ = self.topk_topp_sampler( logits, sampling_metadata.generators, sampling_metadata.top_k, diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 3f0fad8a64d0a..c8375d6f15517 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -18,12 +18,15 @@ from msgspec import msgpack from vllm import envs from vllm.logger import init_logger +# yapf: disable from vllm.multimodal.inputs import (BaseMultiModalField, MultiModalBatchedField, MultiModalFieldConfig, MultiModalFieldElem, MultiModalFlatField, MultiModalKwargs, MultiModalKwargsItem, + MultiModalKwargsItems, MultiModalSharedField, NestedTensors) +# yapf: enable from vllm.v1.engine import UtilityResult logger = init_logger(__name__) @@ -116,19 +119,11 @@ class MsgpackEncoder: if isinstance(obj, MultiModalKwargsItem): return self._encode_mm_item(obj) - if isinstance(obj, MultiModalKwargs): - mm: MultiModalKwargs = obj - if not mm.modalities: - # just return the main dict if there are no modalities. - return dict(mm) + if isinstance(obj, MultiModalKwargsItems): + return self._encode_mm_items(obj) - # ignore the main dict, it will be re-indexed. - # Any tensors *not* indexed by modality will be ignored. 
- return [ - self._encode_mm_item(item) - for itemlist in mm._items_by_modality.values() - for item in itemlist - ] + if isinstance(obj, MultiModalKwargs): + return self._encode_mm_kwargs(obj) if isinstance(obj, UtilityResult): result = obj.result @@ -190,6 +185,12 @@ class MsgpackEncoder: dtype = str(obj.dtype).removeprefix("torch.") return dtype, obj.shape, data + def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]: + return { + modality: [self._encode_mm_item(item) for item in itemlist] + for modality, itemlist in items.items() + } + def _encode_mm_item(self, item: MultiModalKwargsItem) -> list[dict[str, Any]]: return [self._encode_mm_field_elem(elem) for elem in item.values()] @@ -207,6 +208,12 @@ class MsgpackEncoder: self._encode_mm_field(elem.field), } + def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]: + return { + modality: self._encode_nested_tensors(data) + for modality, data in kw.items() + } + def _encode_nested_tensors(self, nt: NestedTensors) -> Any: if isinstance(nt, torch.Tensor): return self._encode_tensor(nt) @@ -267,14 +274,10 @@ class MsgpackDecoder: return slice(*obj) if issubclass(t, MultiModalKwargsItem): return self._decode_mm_item(obj) + if issubclass(t, MultiModalKwargsItems): + return self._decode_mm_items(obj) if issubclass(t, MultiModalKwargs): - if isinstance(obj, list): - return MultiModalKwargs.from_items( - self._decode_mm_items(obj)) - return MultiModalKwargs({ - k: self._decode_nested_tensors(v) - for k, v in obj.items() - }) + return self._decode_mm_kwargs(obj) if t is UtilityResult: return self._decode_utility_result(obj) return obj @@ -328,8 +331,11 @@ class MsgpackDecoder: # Convert back to proper shape & type return arr.view(torch_dtype).view(shape) - def _decode_mm_items(self, obj: list[Any]) -> list[MultiModalKwargsItem]: - return [self._decode_mm_item(v) for v in obj] + def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems: + return MultiModalKwargsItems({ + 
modality: [self._decode_mm_item(item) for item in itemlist] + for modality, itemlist in obj.items() + }) def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem: return MultiModalKwargsItem.from_elems( @@ -352,6 +358,12 @@ class MsgpackDecoder: obj["field"] = factory_meth(None, *field_args).field return MultiModalFieldElem(**obj) + def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs: + return MultiModalKwargs({ + modality: self._decode_nested_tensors(data) + for modality, data in obj.items() + }) + def _decode_nested_tensors(self, obj: Any) -> NestedTensors: if isinstance(obj, (int, float)): # Although it violates NestedTensors type, MultiModalKwargs diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index a8a160a0f9953..0a0e9fed725cb 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -2,7 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast from dataclasses import replace -from typing import Optional +from importlib.util import find_spec +from typing import Optional, Protocol import numpy as np import torch @@ -20,8 +21,6 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.rocm_aiter_fa import ( - AiterFlashAttentionMetadata) from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, TreeAttentionMetadataBuilder) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata @@ -34,6 +33,17 @@ logger = init_logger(__name__) PADDING_SLOT_ID = -1 +class EagleAttentionMetadata(Protocol): + # Required attributes + num_actual_tokens: int + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + class EagleProposer: 
def __init__( @@ -97,6 +107,20 @@ class EagleProposer: dtype=self.dtype, device=device) + # Determine allowed attention backends once during initialization. + self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...] + if current_platform.is_rocm(): + rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] + # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend + if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): + from vllm.v1.attention.backends.rocm_aiter_fa import ( + AiterFlashAttentionMetadata) + rocm_types.append(AiterFlashAttentionMetadata) + self.allowed_attn_types = tuple(rocm_types) + else: + self.allowed_attn_types = (FlashAttentionMetadata, + TreeAttentionMetadata) + # Parse the speculative token tree. spec_token_tree = self.speculative_config.speculative_token_tree self.tree_choices: list[tuple[int, @@ -165,7 +189,7 @@ class EagleProposer: for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: + num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: num_input_tokens = num_tokens @@ -194,7 +218,7 @@ class EagleProposer: hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=inputs_embeds, ) - if self.method == "deepseek_mtp": + if self.method in ("deepseek_mtp", "ernie_mtp"): last_hidden_states = ret_hidden_states else: last_hidden_states, hidden_states = ret_hidden_states @@ -225,25 +249,13 @@ class EagleProposer: # TODO: Currently, MTP module released by deepseek only has # one layer. Adapt this code to support multiple layers once # there's a multi-layer MTP module. - - # On ROCm, both AiterFlashAttention and TritonAttention - # support multi-token eagle spec decode. 
- if current_platform.is_rocm(): - assert isinstance( - attn_metadata, - (TritonAttentionMetadata, AiterFlashAttentionMetadata, - FlashAttentionMetadata)) - else: - # Currently, only FlashAttention supports multi-token eagle spec - # decode. This is because the code below makes assumptions about - # attn_metadata attributes available. - assert isinstance(attn_metadata, FlashAttentionMetadata) + assert isinstance(attn_metadata, self.allowed_attn_types) # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] if self.use_cuda_graph and \ - batch_size <= self.cudagraph_batch_sizes[-1]: + batch_size <= self.cudagraph_batch_sizes[-1]: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: input_batch_size = batch_size @@ -449,7 +461,7 @@ class EagleProposer: num_tokens, -1) if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: + num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph( num_tokens) else: @@ -508,19 +520,19 @@ class EagleProposer: """ # E.g. 
# common_attn_metadata.query_start_loc{_cpu}: - # [0, q1, q1 + q2, q1 + q2 + q3] + # [0, q1, q1 + q2, q1 + q2 + q3] # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] # num_rejected_tokens: [n1, n2, n3] # This function computes the intermediate values: # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] # And returns: # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] # common_attn_metadata.seq_lens{_cpu}: - # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] + # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] # token_indices: [0, 1, ..., q1 - n1 - 1, - # q1, q1 + 1, ..., q1 + q2 - n2 - 1, - # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + # q1, q1 + 1, ..., q1 + q2 - n2 - 1, + # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] device = common_attn_metadata.query_start_loc.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu @@ -564,9 +576,9 @@ class EagleProposer: old_query_start_locs_expanded = np.repeat( query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) # Final token indices are: - # [0, 1, // req 1 - # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 - # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 + # [0, 1, // req 1 + # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 + # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 token_indices_np = token_offests + old_query_start_locs_expanded token_indices = torch.from_numpy(token_indices_np).to( device, non_blocking=True) @@ -582,6 +594,7 @@ class EagleProposer: num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), + max_seq_len=new_seq_lens_cpu.max().item(), block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, @@ -615,20 +628,18 @@ class EagleProposer: target_language_model = target_model # share embed_tokens with the target model if needed 
if get_pp_group().world_size == 1 \ - and self.model.model.embed_tokens.weight.shape \ - == target_language_model.model.embed_tokens.weight.shape: + and self.model.model.embed_tokens.weight.shape \ + == target_language_model.model.embed_tokens.weight.shape: logger.info( - "Assuming the EAGLE head shares the same vocab embedding" \ - " with the target model." - ) + "Assuming the EAGLE head shares the same vocab embedding" + " with the target model.") del self.model.model.embed_tokens self.model.model.embed_tokens = ( target_language_model.model.embed_tokens) else: logger.info( - "The EAGLE head's vocab embedding will be loaded separately" \ - " from the target model." - ) + "The EAGLE head's vocab embedding will be loaded separately" + " from the target model.") # share lm_head with the target model if needed # some model definition do not define lm_head explicitly diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index 309fd926aecd7..3e90179e78d99 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -38,12 +38,14 @@ class MedusaProposer: self, target_hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: + ) -> list[list[int]]: # Generate blocks and compute logits blocks = self.model(target_hidden_states) logits = self.model.compute_logits(blocks, None) # Get draft tokens and transpose the result + # TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU + # synchronization. 
draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits] return [list(row) for row in zip(*draft_tokens)] diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 63604a335d9f0..57854cc112041 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -108,6 +108,14 @@ class StructuredOutputManager: tokenizer=self.tokenizer, vocab_size=vocab_size, ) + elif backend == "lm-format-enforcer": + from vllm.v1.structured_output.backend_lm_format_enforcer import ( # noqa: E501 + LMFormatEnforcerBackend) + self.backend = LMFormatEnforcerBackend( + self.vllm_config, + tokenizer=self.tokenizer, + vocab_size=vocab_size, + ) else: raise ValueError( f"Unsupported structured output backend: {backend}") @@ -267,7 +275,7 @@ class StructuredOutputManager: assert request.structured_output_request is not None assert request.structured_output_request.grammar is not None # by default, we should always advance - # for cases that doesn't uses thinking mode. + # for cases that don't use thinking mode. 
if self.reasoner is not None: structured_req = request.structured_output_request @@ -276,7 +284,7 @@ class StructuredOutputManager: # Check if reasoning ends in *this* step if self.reasoner.is_reasoning_end(request.all_token_ids): - # Reasoning just ended, so we shouldn't advanced til + # Reasoning just ended, so we shouldn't advance til # next pass structured_req.reasoning_ended = True diff --git a/vllm/v1/structured_output/backend_lm_format_enforcer.py b/vllm/v1/structured_output/backend_lm_format_enforcer.py new file mode 100644 index 0000000000000..2279a1c8c8a00 --- /dev/null +++ b/vllm/v1/structured_output/backend_lm_format_enforcer.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +import ast +import json +from dataclasses import dataclass, field +from functools import lru_cache +from typing import TYPE_CHECKING + +import torch +from transformers import PreTrainedTokenizerBase + +from vllm.sampling_params import SamplingParams +from vllm.utils import LazyLoader +from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions) + +if TYPE_CHECKING: + import lmformatenforcer + import lmformatenforcer.integrations.vllm as lmfe_vllm +else: + lmformatenforcer = LazyLoader("lmformatenforcer", globals(), + "lmformatenforcer") + lmfe_vllm = LazyLoader("lmformatenforcer.integrations.vllm", globals(), + "lmformatenforcer.integrations.vllm") + + +@lru_cache +def _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer: PreTrainedTokenizerBase, + vocab_size: int) -> lmfe_vllm.TokenEnforcerTokenizerData: + return lmfe_vllm.build_vllm_token_enforcer_tokenizer_data( + tokenizer, use_bitmask=True, vocab_size=vocab_size) + + +@dataclass +class LMFormatEnforcerGrammar(StructuredOutputGrammar): + token_enforcer: lmformatenforcer.TokenEnforcer + current_tokens_prefix: list[int] = 
field(default_factory=list) + + def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: + original_len = len(self.current_tokens_prefix) + for token in tokens: + if not self.token_enforcer.get_allowed_tokens( + self.current_tokens_prefix).is_token_allowed(token): + # Rollback partial updates to ensure atomicity. + del self.current_tokens_prefix[original_len:] + return False + self.current_tokens_prefix.append(token) + return True + + def validate_tokens(self, tokens: list[int]) -> list[int]: + for prefix_length in range(len(tokens)): + prefix = tokens[:prefix_length] + next_token = tokens[prefix_length] + if not self.token_enforcer.get_allowed_tokens( + self.current_tokens_prefix + + prefix).is_token_allowed(next_token): + break + else: + return tokens + + return tokens[:prefix_length] + + def rollback(self, num_tokens: int) -> None: + self.current_tokens_prefix = self.current_tokens_prefix[:-num_tokens] + + def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None: + allowed_tokens = self.token_enforcer.get_allowed_tokens( + self.current_tokens_prefix) + bitmask[batch_index] = allowed_tokens.allowed_tokens + + def is_terminated(self) -> bool: + # We are considered terminated if the prefix ends with eos_token_id + return_value = len( + self.current_tokens_prefix) > 0 and self.current_tokens_prefix[ + -1] == self.token_enforcer.eos_token_id + return return_value + + def reset(self): + self.current_tokens_prefix = [] + + +@dataclass +class LMFormatEnforcerBackend(StructuredOutputBackend): + + def __post_init__(self): + self.tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( + self.tokenizer, self.vocab_size) + + def compile_grammar(self, request_type: StructuredOutputOptions, + grammar_spec: str) -> StructuredOutputGrammar: + character_level_parser: lmformatenforcer.CharacterLevelParser + if request_type == StructuredOutputOptions.JSON: + spec_dict = json.loads(grammar_spec) + character_level_parser = 
lmformatenforcer.JsonSchemaParser( + spec_dict) + elif request_type == StructuredOutputOptions.JSON_OBJECT: + character_level_parser = lmformatenforcer.JsonSchemaParser(None) + elif request_type == StructuredOutputOptions.REGEX: + character_level_parser = lmformatenforcer.RegexParser(grammar_spec) + elif request_type == StructuredOutputOptions.CHOICE: + choices = ast.literal_eval(grammar_spec) + character_level_parser = lmformatenforcer.UnionParser( + [lmformatenforcer.StringParser(choice) for choice in choices]) + else: + raise ValueError( + "Invalid request type for LM Format Enforcer backend" + f"({request_type!s})") + max_rollback_tokens = ( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config is not None else 0) + + if max_rollback_tokens > 0: + raise ValueError( + "LM Format Enforcer backend does not support speculative tokens" + ) + + token_enforcer = lmformatenforcer.TokenEnforcer( + tokenizer_data=self.tokenizer_data, + parser=character_level_parser, + ) + return LMFormatEnforcerGrammar(token_enforcer) + + def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor: + return torch.full( + (max_num_seqs, (self.vocab_size + 31) // 32), + -1, + dtype=torch.int32, + pin_memory=torch.cuda.is_available(), + ) + + def destroy(self): + pass + + +def validate_structured_output_request_lm_format_enforcer( + params: SamplingParams): + if params.guided_decoding is None: + return + + gd_params = params.guided_decoding + + if gd_params.regex: + return + elif gd_params.json: + if isinstance(gd_params.json, str): + try: + # make sure schema is valid json + json.loads(gd_params.json) + except json.JSONDecodeError as e: + raise ValueError("Invalid JSON grammar specification.") from e + else: + try: + json.dumps(gd_params.json) + except Exception as e: + raise ValueError( + f"Error serializing guided decoding jsonschema: {e}" + ) from e + return + elif gd_params.choice: + return + elif gd_params.grammar: + raise 
ValueError("LM Format Enforcer guided decoding backend " + "does not support grammar specifications") diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index b5750c82db023..8f9face6fbf2e 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -96,6 +96,35 @@ class ConstantList(Generic[T], Sequence): return f"ConstantList({self._x})" +class CpuGpuBuffer: + + def __init__( + self, + *args, + dtype: torch.dtype, + device: torch.device, + pin_memory: bool, + ): + self.cpu = torch.zeros(*args, + dtype=dtype, + device="cpu", + pin_memory=pin_memory) + self.np = self.cpu.numpy() + self.gpu = self.cpu.to(device) + + def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor: + if n is None: + return self.gpu.copy_(self.cpu, non_blocking=True) + return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True) + + def copy_to_cpu(self, n: Optional[int] = None) -> torch.Tensor: + """NOTE: Because this method is non-blocking, explicit synchronization + is needed to ensure the data is copied to CPU.""" + if n is None: + return self.cpu.copy_(self.gpu, non_blocking=True) + return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True) + + def get_engine_client_zmq_addr(local_only: bool, host: str, port: int = 0) -> str: diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index bf38e88f0c2a1..5662fc350e198 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -91,8 +91,7 @@ class BlockTable: # block_size. 
block_table_indices = (req_indices * self.max_num_blocks_per_req + positions // self.block_size) - block_table_cpu = self.get_cpu_tensor() - block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + block_numbers = self.block_table_np.ravel()[block_table_indices] block_offsets = positions % self.block_size np.add(block_numbers * self.block_size, block_offsets, diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index 11b96d946365d..742e553b77e09 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -10,6 +10,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1 +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_model_runner import GPUModelRunner if TYPE_CHECKING: @@ -21,7 +22,8 @@ logger = init_logger(__name__) class CPUModelRunner(GPUModelRunner): def __init__(self, vllm_config: VllmConfig, device: torch.device): - super().__init__(vllm_config, device) + with _torch_cuda_wrapper(): + super().__init__(vllm_config, device) assert device == torch.device("cpu") assert self.speculative_config is None, "spec decode is not supported." 
@@ -29,7 +31,7 @@ class CPUModelRunner(GPUModelRunner): self.use_cuda_graph = False self.cascade_attn_enabled = False - self._postprocess_tenosrs() + self._postprocess_tensors() def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ @@ -59,7 +61,7 @@ class CPUModelRunner(GPUModelRunner): self.attn_groups[0][0].metadata_builder.reorder_batch( self.input_batch, scheduler_output) - def _postprocess_tenosrs(self) -> None: + def _postprocess_tensors(self) -> None: # Note: replace device tensors with cpu tensors def replace_tensor(obj: Any, cpu_attr_name: str, device_attr_name) -> None: @@ -71,8 +73,8 @@ class CPUModelRunner(GPUModelRunner): setattr(obj, device_attr_name, cpu_tensor) for k, v in vars(self).items(): - if k.endswith("_cpu") and isinstance(v, torch.Tensor): - replace_tensor(self, k, k[:-4]) + if isinstance(v, CpuGpuBuffer): + v.gpu = v.cpu for k, v in vars(self.input_batch).items(): if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor): @@ -108,6 +110,26 @@ class CPUModelRunner(GPUModelRunner): def _sync_device(self) -> None: pass + def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: + return sampled_token_ids.tolist() + + +@contextmanager +def _torch_cuda_wrapper(): + + class _EventPlaceholder: + + def __init__(self, *args, **kwargs) -> None: + self.record = lambda: None + self.synchronize = lambda: None + + try: + cuda_event = torch.cuda.Event + torch.cuda.Event = _EventPlaceholder + yield + finally: + torch.cuda.Event = cuda_event + @contextmanager def _set_global_compilation_settings(config: VllmConfig): diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index 2dc28d93049ab..be78597926e09 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -43,8 +43,9 @@ class CPUWorker(Worker): # Setup OpenMP threads affinity. 
omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND if omp_cpuids == "auto" and platform.system() == "Linux": - if current_platform.get_cpu_architecture() == CpuArchEnum.POWERPC: - # For POWERPC SMT-8/4/2 + cpu_arch = current_platform.get_cpu_architecture() + if cpu_arch in (CpuArchEnum.POWERPC, CpuArchEnum.S390X): + # For S390X/POWERPC SMT-8/4/2 self.local_omp_cpuid = self._get_autobind_cpu_ids( lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4]) elif current_platform.get_cpu_architecture() == CpuArchEnum.X86: @@ -132,7 +133,7 @@ class CPUWorker(Worker): """ allowed_numa_nodes, logical_cpu_list = \ - CpuPlatform.get_allowed_cpu_memory_node_list() + CpuPlatform.get_allowed_cpu_core_node_list() assert len(allowed_numa_nodes) >= self.parallel_config.world_size, ( f"No enough allowed NUMA nodes to bind threads of " f"{self.parallel_config.world_size} CPUWorkers. " diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 2469e09f8249d..284af6bfedce0 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -10,16 +10,16 @@ import torch from typing_extensions import deprecated from vllm.lora.request import LoRARequest -from vllm.multimodal.inputs import (MultiModalKwargs, MultiModalKwargsItem, - PlaceholderRange) +from vllm.multimodal.inputs import (MultiModalKwargsItem, + MultiModalKwargsItems, PlaceholderRange) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import swap_dict_values from vllm.v1.outputs import LogprobsTensors from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, - MoveDirectionality, - init_builtin_logitsprocs) + LogitsProcessors, + MoveDirectionality) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.utils import is_spec_decode_unsupported from vllm.v1.utils import copy_slice @@ -33,6 +33,7 @@ class CachedRequestState: 
prompt_token_ids: list[int] mm_kwargs: list[MultiModalKwargsItem] mm_positions: list[PlaceholderRange] + mm_hashes: list[str] sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] generator: Optional[torch.Generator] @@ -57,14 +58,15 @@ class CachedRequestState: @property @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be " "removed in v0.13. Please use `mm_kwargs` instead.") - def mm_inputs(self) -> list[MultiModalKwargs]: - return [MultiModalKwargs.from_items([item]) for item in self.mm_kwargs] + def mm_inputs(self) -> list[MultiModalKwargsItems]: + return [ + MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs + ] def get_token_id(self, idx: int) -> int: if idx < self.num_prompt_tokens: return self.prompt_token_ids[idx] - else: - return self.output_token_ids[idx - self.num_prompt_tokens] + return self.output_token_ids[idx - self.num_prompt_tokens] class InputBatch: @@ -78,8 +80,11 @@ class InputBatch: pin_memory: bool, vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group + logitsprocs: Optional[LogitsProcessors] = None, is_spec_decode: bool = False, + is_pooling_model: bool = False, ): + self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len @@ -221,14 +226,6 @@ class InputBatch: # updates. Should reset each step. self.batch_update_builder = BatchUpdateBuilder() - # Define logits processors. - # TODO(andy): logits processor list should be extensible via engine - # constructor argument; for now the list is fixed. 
- self.logitsprocs = init_builtin_logitsprocs( - pin_memory_available=pin_memory, - max_num_reqs=max_num_reqs + 1, - device=device) - # TODO convert this to LogitsProcessor self.has_allowed_token_ids: set[str] = set() # NOTE(lufang): In the mask tensor, if the corresponding token allowed, @@ -244,6 +241,10 @@ class InputBatch: self.req_output_token_ids: list[Optional[list[int]]] = [] + # Store provided logitsprocs. If none are provided, initialize empty + # data structure + self.logitsprocs = logitsprocs or LogitsProcessors() + # This is updated each time the batch constituents change. self.sampling_metadata = self._make_sampling_metadata() @@ -255,22 +256,26 @@ class InputBatch: # while performing state updates to the batch. return cast(list[str], self._req_ids) - def _get_next_add_index(self) -> int: - if (req_index := self.batch_update_builder.pop_removed()) is not None: - # Fill the empty index. - return req_index - # Append to end - return self.num_reqs - def _register_add_request(self, request: "CachedRequestState") -> int: - """Track add-request operations""" - req_index = self._get_next_add_index() - assert req_index < self.max_num_reqs - params = (request.sampling_params - if request.sampling_params else request.pooling_params) - self.batch_update_builder.added.append( - (req_index, params, request.output_token_ids)) - return req_index + """Track add-request operations for logits processors. + Not applicable to pooling models. + """ + + # Fill the next empty index if there is one. + if (new_req_index := self.batch_update_builder.pop_removed()) is None: + # Append to end otherwise. + new_req_index = self.num_reqs + + assert new_req_index < self.max_num_reqs + self.batch_update_builder.batch_changed = True + if request.sampling_params: + # Detailed added request metadata is only required for non-pooling + # models, to support logitsprocs. 
+ self.batch_update_builder.added.append( + (new_req_index, request.sampling_params, + request.prompt_token_ids, request.output_token_ids)) + + return new_req_index def add_request( self, @@ -381,7 +386,7 @@ class InputBatch: self.logits_processing_needs_token_ids[req_index] = ( pooling_params.requires_token_ids) else: - raise NotImplementedError(request) + raise NotImplementedError("Unrecognized request type") # Add request lora ID if request.lora_request: @@ -411,10 +416,25 @@ class InputBatch: req_index = self.req_id_to_index.pop(req_id, None) if req_index is None: return None + self.batch_update_builder.removed_append(req_index) self._req_ids[req_index] = None self.req_output_token_ids[req_index] = None + # LoRA + lora_id = self.request_lora_mapping[req_index] + if lora_id != 0: + lora_req_ids = self.lora_id_to_request_ids[lora_id] + lora_req_ids.discard(req_id) + if not lora_req_ids: + del self.lora_id_to_request_ids[lora_id] + del self.lora_id_to_lora_request[lora_id] + self.request_lora_mapping[req_index] = 0 + + if self.is_pooling_model: + self.pooling_params.pop(req_id, None) + return req_index + self.greedy_reqs.discard(req_id) self.random_reqs.discard(req_id) self.top_p_reqs.discard(req_id) @@ -428,26 +448,14 @@ class InputBatch: self.num_prompt_logprobs.pop(req_id, None) self.in_progress_prompt_logprobs_cpu.pop(req_id, None) - # LoRA - lora_id = self.request_lora_mapping[req_index] - if lora_id != 0: - self.lora_id_to_request_ids[lora_id].discard(req_id) - if len(self.lora_id_to_request_ids[lora_id]) == 0: - self.lora_id_to_request_ids.pop(lora_id) - self.lora_id_to_lora_request.pop(lora_id) - self.request_lora_mapping[req_index] = 0 - self.has_allowed_token_ids.discard(req_id) if self.allowed_token_ids_mask_cpu_tensor is not None: # False means we don't fill with -inf. 
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False) self.bad_words_token_ids.pop(req_index, None) - self.pooling_params.pop(req_id, None) return req_index def swap_states(self, i1: int, i2: int) -> None: - self.batch_update_builder.moved.append( - (i1, i2, MoveDirectionality.SWAP)) old_id_i1 = self._req_ids[i1] old_id_i2 = self._req_ids[i2] self._req_ids[i1], self._req_ids[i2] =\ @@ -465,18 +473,6 @@ class InputBatch: self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] - self.temperature_cpu[i1], self.temperature_cpu[i2] =\ - self.temperature_cpu[i2], self.temperature_cpu[i1] - self.top_p_cpu[i1], self.top_p_cpu[i2] =\ - self.top_p_cpu[i2], self.top_p_cpu[i1] - self.top_k_cpu[i1], self.top_k_cpu[i2] =\ - self.top_k_cpu[i2], self.top_k_cpu[i1] - self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\ - self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] - self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\ - self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] - self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\ - self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] # NOTE: the following is unsafe # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ @@ -487,18 +483,41 @@ class InputBatch: self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] self.token_ids_cpu[i2, ...] = tmp + self.block_table.swap_row(i1, i2) + + self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \ + self.request_lora_mapping[i2], self.request_lora_mapping[i1] + + if self.is_pooling_model: + # Sampling and logits parameters don't apply to pooling models. + return + + # For autoregressive models, track detailed request reordering info + # to support logitsprocs. 
+ self.batch_update_builder.moved.append( + (i1, i2, MoveDirectionality.SWAP)) + + self.temperature_cpu[i1], self.temperature_cpu[i2] = \ + self.temperature_cpu[i2], self.temperature_cpu[i1] + self.top_p_cpu[i1], self.top_p_cpu[i2] = \ + self.top_p_cpu[i2], self.top_p_cpu[i1] + self.top_k_cpu[i1], self.top_k_cpu[i2] = \ + self.top_k_cpu[i2], self.top_k_cpu[i1] + self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \ + self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] + self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \ + self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] + self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \ + self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] + swap_dict_values(self.generators, i1, i2) swap_dict_values(self.bad_words_token_ids, i1, i2) - self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\ - self.request_lora_mapping[i2], self.request_lora_mapping[i1] - if self.allowed_token_ids_mask_cpu_tensor is not None: self.allowed_token_ids_mask_cpu_tensor[i1], \ self.allowed_token_ids_mask_cpu_tensor[i2] =\ self.allowed_token_ids_mask_cpu_tensor[i2], \ self.allowed_token_ids_mask_cpu_tensor[i1] - self.block_table.swap_row(i1, i2) def condense(self) -> None: """Slide non-empty requests down into lower, empty indices. @@ -513,11 +532,12 @@ class InputBatch: swaps: list of (from,to) swap tuples for moved requests empty_req_indices: indices not filled by condensation """ + num_reqs = self.num_reqs + if not (empty_req_indices := self.batch_update_builder.removed): # All removed requests were replaced by added requests, or else no # requests were removed at all. No condense() needed return - num_reqs = self.num_reqs if num_reqs == 0: # The batched states are empty. self._req_ids.clear() @@ -541,9 +561,6 @@ class InputBatch: # Move active request down into empty request # index. 
self.batch_update_builder.pop_removed() - self.batch_update_builder.moved.append( - (last_req_index, empty_index, - MoveDirectionality.UNIDIRECTIONAL)) req_id = self._req_ids[last_req_index] output_token_ids = self.req_output_token_ids[last_req_index] assert req_id is not None @@ -564,6 +581,21 @@ class InputBatch: self.num_computed_tokens_cpu[ empty_index] = self.num_computed_tokens_cpu[last_req_index] self.block_table.move_row(last_req_index, empty_index) + + self.request_lora_mapping[empty_index] = self.request_lora_mapping[ + last_req_index] + + if self.is_pooling_model: + last_req_index -= 1 + # Sampling state not used by pooling models. + continue + + # Autoregressive models require detailed tracking of condense + # operations to support logitsprocs + self.batch_update_builder.moved.append( + (last_req_index, empty_index, + MoveDirectionality.UNIDIRECTIONAL)) + self.temperature_cpu[empty_index] = self.temperature_cpu[ last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] @@ -578,9 +610,6 @@ class InputBatch: if generator is not None: self.generators[empty_index] = generator - self.request_lora_mapping[empty_index] = self.request_lora_mapping[ - last_req_index] - # TODO convert these to LogitsProcessors if self.allowed_token_ids_mask_cpu_tensor is not None: self.allowed_token_ids_mask_cpu_tensor[ @@ -596,15 +625,21 @@ class InputBatch: last_req_index -= 1 # Trim lists to the batch size. 
- del self._req_ids[self.num_reqs:] - del self.req_output_token_ids[self.num_reqs:] + del self._req_ids[num_reqs:] + del self.req_output_token_ids[num_reqs:] def refresh_metadata(self): - """Apply batch updates, reset input batch at end of step + """Apply any batch updates to sampling metadata.""" - * Apply batch add/remove/permute to logits procs' states - * If batch state is modified, update sampling metadata - """ + if self.is_pooling_model: + batch_changed = self.batch_update_builder.reset() + if batch_changed: + self.sampling_metadata = self._make_sampling_metadata() + return + + # For non-pooling models - generate and apply logitsprocs update; + # reset batch update tracking. + # Update sampling metadata if batch state is changed. batch_update = self.batch_update_builder.get_and_reset(self.num_reqs) for logit_proc in self.logitsprocs.all: logit_proc.update_state(batch_update) @@ -686,13 +721,14 @@ class InputBatch: return PoolingMetadata( prompt_lens=torch.from_numpy( - self.num_prompt_tokens[:self.num_reqs]).to(self.device), + self.num_prompt_tokens[:self.num_reqs]), prompt_token_ids=self.sampling_metadata.prompt_token_ids, pooling_params=pooling_params, ) def _make_prompt_token_ids_tensor(self) -> torch.Tensor: - max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() + num_reqs = self.num_reqs + max_prompt_len = self.num_prompt_tokens[:num_reqs].max() prompt_token_ids_cpu_tensor = torch.empty( (self.num_reqs, max_prompt_len), device="cpu", @@ -700,11 +736,10 @@ class InputBatch: pin_memory=self.pin_memory, ) prompt_token_ids = prompt_token_ids_cpu_tensor.numpy() - prompt_token_ids[:] = self.token_ids_cpu[:self. - num_reqs, :max_prompt_len] + prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len] # Use the value of vocab_size as a pad since we don't have a # token_id of this value. 
- for i in range(self.num_reqs): + for i in range(num_reqs): prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9460d91c58323..d93460d618e7c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -8,6 +8,7 @@ import time from collections import defaultdict from collections.abc import Iterator from contextlib import contextmanager +from copy import deepcopy from typing import TYPE_CHECKING, Any, Optional, Union, cast import numpy as np @@ -34,7 +35,8 @@ from vllm.distributed.parallel_state import ( from vllm.forward_context import (BatchDescriptor, DPMetadata, set_forward_context) from vllm.logger import init_logger -from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.models.interfaces import (is_mixture_of_experts, @@ -54,7 +56,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, LazyLoader, cdiv, check_use_alibi, get_dtype_size, is_pin_memory_available, round_up, supports_dynamo) -from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, make_kv_sharing_fast_prefill_attention_metadata, @@ -62,12 +63,14 @@ from vllm.v1.attention.backends.utils import ( from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, + EncoderOnlyAttentionSpec, FullAttentionSpec, 
KVCacheConfig, - KVCacheSpec, MambaSpec, - SlidingWindowSpec) -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, - ModelRunnerOutput) + KVCacheGroupSpec, KVCacheSpec, + MambaSpec, SlidingWindowSpec) +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, + LogprobsTensors, ModelRunnerOutput) from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler @@ -75,12 +78,12 @@ from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorModelRunnerMixin, KVConnectorOutput) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from ..sample.logits_processor import LogitsProcessorManager from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache, gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) @@ -136,7 +139,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): cache_config.cache_dtype] self.is_pooling_model = model_config.pooler_config is not None - self.is_encoder_only_model = False self.is_multimodal_raw_input_supported = ( model_config.is_multimodal_raw_input_supported) self.max_model_len = model_config.max_model_len @@ -148,6 +150,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): parallel_config) self.hidden_size = model_config.get_hidden_size() self.attention_chunk_size = model_config.attention_chunk_size + # 
Only relevant for models using ALiBi (e.g, MPT) + self.use_alibi = check_use_alibi(model_config) self.cascade_attn_enabled = not self.model_config.disable_cascade_attn @@ -175,8 +179,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.attn_groups: list[list[AttentionGroup]] = [] # self.kv_cache_config: KVCacheConfig - # req_id -> (input_id -> encoder_output) - self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} + # mm_hash -> encoder_output + self.encoder_cache: dict[str, torch.Tensor] = {} self.use_aux_hidden_state_outputs = False # Set up speculative decoding. @@ -221,37 +225,37 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.cache_config.block_size], is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=build_logitsprocs( + self.vllm_config, self.device, self.pin_memory, + self.is_pooling_model, + self.vllm_config.model_config.logits_processors), + is_pooling_model=self.is_pooling_model, ) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. # The batch sizes in the config are in descending order. - self.cudagraph_batch_sizes = list( - reversed(self.compilation_config.cudagraph_capture_sizes)) + if self.compilation_config.cudagraph_capture_sizes and \ + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + self.cudagraph_batch_sizes = list( + reversed(self.compilation_config.cudagraph_capture_sizes)) # Cache the device properties. self._init_device_properties() # Persistent buffers for CUDA graphs. 
- self.input_ids = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=self.device) - self.positions = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=self.device) - self.query_start_loc = torch.zeros(self.max_num_reqs + 1, - dtype=torch.int32, - device=self.device) - self.seq_lens = torch.zeros(self.max_num_reqs, - dtype=torch.int32, - device=self.device) - self.slot_mapping = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=self.device) - - # None in the first PP rank. The rest are set after load_model. - self.intermediate_tensors: Optional[IntermediateTensors] = None + self.input_ids = self._make_buffer(self.max_num_tokens, + dtype=torch.int32) + self.positions = self._make_buffer(self.max_num_tokens, + dtype=torch.int64) + self.query_start_loc = self._make_buffer(self.max_num_reqs + 1, + dtype=torch.int32) + self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + self.inputs_embeds = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=self.device) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -265,23 +269,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # identical position IDs, making M-RoPE functionally equivalent to # 1D-RoPE. 
# See page 5 of https://arxiv.org/abs/2409.12191 - self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1), - dtype=torch.int64, - device=self.device) - self.mrope_positions_cpu = torch.zeros( - (3, self.max_num_tokens + 1), - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory) - self.mrope_positions_np = self.mrope_positions_cpu.numpy() + self.mrope_positions = self._make_buffer( + (3, self.max_num_tokens + 1), dtype=torch.int64) - # Only relevant for models using ALiBi (e.g, MPT) - self.use_alibi = check_use_alibi(model_config) - - self.inputs_embeds = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=self.device) + # None in the first PP rank. The rest are set after load_model. + self.intermediate_tensors: Optional[IntermediateTensors] = None # OPTIMIZATION: Cache the tensors rather than creating them every step. # Keep in int64 to avoid overflow with long context @@ -289,28 +281,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.max_model_len, self.max_num_tokens), dtype=np.int64) - # NOTE(woosuk): These tensors are "stateless", i.e., they are literally - # a faster version of creating a new tensor every time. Thus, we should - # not make any assumptions about the values in these tensors. - self.input_ids_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.positions_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory) - self.positions_np = self.positions_cpu.numpy() - self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.query_start_loc_np = self.query_start_loc_cpu.numpy() - self.seq_lens_cpu = torch.zeros(self.max_num_reqs, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.seq_lens_np = self.seq_lens_cpu.numpy() # Layer pairings for cross-layer KV sharing. 
# If an Attention layer `layer_name` is in the keys of this dict, it @@ -334,13 +304,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.model_config, self.scheduler_config, self.mm_registry, - max_model_len=self.max_model_len, - max_num_reqs=self.max_num_reqs, - ) if self.supports_mm_inputs \ - else None) + ) if self.supports_mm_inputs else None) self.reorder_batch_threshold: Optional[int] = None + # Attention layers that are only in the KVCacheConfig of the runner + # (e.g., KV sharing, encoder-only attention), but not in the + # KVCacheConfig of the scheduler. + self.runner_only_attn_layers: set[str] = set() + + # Cached outputs. + self._draft_token_ids: Optional[Union[list[list[int]], + torch.Tensor]] = None + self.transfer_event = torch.cuda.Event() + self.sampled_token_ids_pinned_cpu = torch.empty( + (self.max_model_len, 1), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + + def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer: + return CpuGpuBuffer(*args, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory) + def _init_model_kwargs(self, num_tokens: int): model_kwargs = dict[str, Any]() num_reqs = self.input_batch.num_reqs @@ -350,6 +338,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if num_pooling_reqs == 0: return model_kwargs + # This does nontrivial work. pooling_params = self.input_batch.pooling_metadata.pooling_params assert num_pooling_reqs == num_reqs @@ -364,7 +353,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if len(token_type_id_requests) == 0: return model_kwargs - seq_lens = self.seq_lens[:num_reqs] + seq_lens = self.seq_lens.gpu[:num_reqs] token_type_ids = [] for i in range(num_reqs): @@ -424,7 +413,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Remove finished requests from the cached states. 
for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) - self.encoder_cache.pop(req_id, None) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. This happens when a request is aborted and @@ -435,12 +423,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.input_batch.remove_request(req_id) # Free the cached encoder outputs. - for req_id, input_id in scheduler_output.free_encoder_input_ids: - encoder_outputs = self.encoder_cache.get(req_id) - if encoder_outputs is not None: - encoder_outputs.pop(input_id, None) - if not encoder_outputs: - self.encoder_cache.pop(req_id, None) + for mm_hash in scheduler_output.free_encoder_mm_hashes: + self.encoder_cache.pop(mm_hash, None) # Remove the unscheduled requests from the persistent batch. # NOTE(woosuk): The unscheduled requests are either preempted requests @@ -457,7 +441,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for req_id in unscheduled_req_ids: self.input_batch.remove_request(req_id) - req_ids_to_add: list[str] = [] + reqs_to_add: list[CachedRequestState] = [] # Add new requests to the cached states. 
for new_req_data in scheduler_output.scheduled_new_reqs: req_id = new_req_data.req_id @@ -472,18 +456,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): generator = None if pooling_params: - assert (task := pooling_params.task) is not None, ( - "You did not set `task` in the API") + task = pooling_params.task + assert task is not None, "You did not set `task` in the API" model = cast(VllmModelForPooling, self.get_model()) to_update = model.pooler.get_pooling_updates(task) to_update.apply(pooling_params) - self.requests[req_id] = CachedRequestState( + req_state = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, mm_kwargs=new_req_data.mm_kwargs, mm_positions=new_req_data.mm_positions, + mm_hashes=new_req_data.mm_hashes, sampling_params=sampling_params, pooling_params=pooling_params, generator=generator, @@ -492,46 +477,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): output_token_ids=[], lora_request=new_req_data.lora_request, ) + self.requests[req_id] = req_state # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: - image_grid_thw = [] - video_grid_thw = [] - second_per_grid_ts = [] - audio_feature_lengths = [] - use_audio_in_video = False - for item in self.requests[req_id].mm_kwargs: - mm_input = item.require_data() - if mm_input.get("image_grid_thw") is not None: - image_grid_thw.append( - mm_input["image_grid_thw"].tolist()) - if mm_input.get("video_grid_thw") is not None: - video_grid_thw.append( - mm_input["video_grid_thw"].tolist()) - if mm_input.get("second_per_grid_ts") is not None: - second_per_grid_ts.append( - mm_input["second_per_grid_ts"]) - if mm_input.get("audio_feature_lengths") is not None: - audio_feature_lengths.append( - mm_input["audio_feature_lengths"]) - if mm_input.get("use_audio_in_video") is True: - use_audio_in_video = True + self._init_mrope_positions(req_state) - hf_config = self.model_config.hf_config - - 
self.requests[req_id].mrope_positions, \ - self.requests[req_id].mrope_position_delta = \ - MRotaryEmbedding.get_input_positions_tensor( - self.requests[req_id].prompt_token_ids, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - - req_ids_to_add.append(req_id) + reqs_to_add.append(req_state) # Update the states of the running/resumed requests. is_last_rank = get_pp_group().is_last_rank @@ -563,11 +515,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Update the block IDs. if not resumed_from_preemption: - # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): - block_ids.extend(new_ids) + if new_block_ids is not None: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip(req_state.block_ids, + new_block_ids): + block_ids.extend(new_ids) else: + assert new_block_ids is not None # The request is resumed from preemption. # Replace the existing block IDs with the new ones. req_state.block_ids = new_block_ids @@ -577,13 +531,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # The request is not in the persistent batch. # The request was either preempted and resumed later, or was not # scheduled in the previous step and needs to be added again. - req_ids_to_add.append(req_id) + reqs_to_add.append(req_state) continue # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - self.input_batch.block_table.append_row(new_block_ids, req_index) + if new_block_ids is not None: + self.input_batch.block_table.append_row( + new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu # because the sampled tokens are already cached. 
@@ -612,9 +568,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. - for req_id in req_ids_to_add: - req_state = self.requests[req_id] - self.input_batch.add_request(req_state) + for request in reqs_to_add: + self.input_batch.add_request(request) # Condense the batched states if there are gaps left by removed requests self.input_batch.condense() @@ -623,42 +578,66 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() + def _init_mrope_positions(self, req_state: CachedRequestState): + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False + for mm_item in req_state.mm_kwargs: + mm_input = mm_item.get_data() + if (t := mm_input.get("image_grid_thw")) is not None: + image_grid_thw.append(t.tolist()) + if (t := mm_input.get("video_grid_thw")) is not None: + video_grid_thw.append(t.tolist()) + if (t := mm_input.get("second_per_grid_ts")) is not None: + second_per_grid_ts.append(t) + if (t := mm_input.get("audio_feature_lengths")) is not None: + audio_feature_lengths.append(t) + if mm_input.get("use_audio_in_video") is True: + use_audio_in_video = True + + req_state.mrope_positions, req_state.mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + req_state.prompt_token_ids, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + def _extract_mm_kwargs( self, scheduler_output: "SchedulerOutput", ) -> BatchedTensorInputs: - if self.is_multimodal_raw_input_supported: # noqa: SIM102 - if scheduler_output: - mm_kwargs = list[MultiModalKwargsItem]() - for req in 
scheduler_output.scheduled_new_reqs: - req_mm_kwargs = req.mm_kwargs - if not isinstance(req_mm_kwargs, list): - req_mm_kwargs = list(req_mm_kwargs) - mm_kwargs.extend(req_mm_kwargs) + if not self.is_multimodal_raw_input_supported or not scheduler_output: # noqa: SIM102 + return {} - # Input all modalities at once - mm_kwargs_combined: BatchedTensorInputs = {} - for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - ): - mm_kwargs_combined.update(mm_kwargs_group) + mm_kwargs = list[MultiModalKwargsItem]() + for req in scheduler_output.scheduled_new_reqs: + mm_kwargs.extend(req.mm_kwargs) - return mm_kwargs_combined + # Input all modalities at once + mm_kwargs_combined: BatchedTensorInputs = {} + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + ): + mm_kwargs_combined.update(mm_kwargs_group) - return {} + return mm_kwargs_combined def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: - if self.is_multimodal_raw_input_supported: - mm_budget = self.mm_budget - assert mm_budget is not None + if not self.is_multimodal_raw_input_supported: + return {} + mm_budget = self.mm_budget + assert mm_budget is not None - dummy_modality, _ = mm_budget.get_modality_with_max_tokens() - - return self._get_mm_dummy_batch(dummy_modality, num_seqs) - - return {} + dummy_modality = mm_budget.get_modality_with_max_tokens() + return self._get_mm_dummy_batch(dummy_modality, num_seqs) def _get_cumsum_and_arange( self, @@ -717,7 +696,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_scheduled_tokens) # Get positions. 
- positions_np = self.positions_np[:total_num_scheduled_tokens] + positions_np = self.positions.np[:total_num_scheduled_tokens] np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np) @@ -740,7 +719,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), 0, torch.from_numpy(token_indices), - out=self.input_ids_cpu[:total_num_scheduled_tokens]) + out=self.input_ids.cpu[:total_num_scheduled_tokens]) self.input_batch.block_table.compute_slot_mapping( req_indices, positions_np) @@ -748,42 +727,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): total_num_scheduled_tokens) # Prepare the attention metadata. - self.query_start_loc_np[0] = 0 - self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + self.query_start_loc.np[0] = 0 + self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens + # Note: pad query_start_loc to be non-decreasing, as kernels + # like FlashAttention requires that + self.query_start_loc.np[num_reqs + 1:].fill(cu_num_tokens[-1]) + self.query_start_loc.copy_to_gpu() + query_start_loc = self.query_start_loc.gpu[:num_reqs + 1] - self.seq_lens_np[:num_reqs] = ( + self.seq_lens.np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens) + # Fill unused with 0 for full cuda graph mode. + self.seq_lens.np[num_reqs:].fill(0) + self.seq_lens.copy_to_gpu() + seq_lens = self.seq_lens.gpu[:num_reqs] + max_seq_len = self.seq_lens.np[:num_reqs].max().item() # Copy the tensors to the GPU. 
- self.input_ids[:total_num_scheduled_tokens].copy_( - self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - self.mrope_positions[:, :total_num_scheduled_tokens].copy_( - self.mrope_positions_cpu[:, :total_num_scheduled_tokens], + self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions.cpu[:, :total_num_scheduled_tokens], non_blocking=True) else: # Common case (1D positions) - self.positions[:total_num_scheduled_tokens].copy_( - self.positions_cpu[:total_num_scheduled_tokens], - non_blocking=True) - - self.query_start_loc[:num_reqs + 1].copy_( - self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) - self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], - non_blocking=True) - - # Fill unused with 0 for full cuda graph mode. - self.seq_lens[num_reqs:].fill_(0) - # Note: pad query_start_loc to be non-decreasing, as kernels - # like FlashAttention requires that - self.query_start_loc[num_reqs + 1:].fill_( - self.query_start_loc_cpu[num_reqs].item()) - - query_start_loc = self.query_start_loc[:num_reqs + 1] - - spec_decode_common_attn_metadata = None + self.positions.copy_to_gpu(total_num_scheduled_tokens) use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -836,46 +806,56 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): attn_metadata: dict[str, Any] = {} - # Prepare encoder attention metadata separately - # (encoder layers are not in KV cache groups) - if self.is_encoder_only_model: - - per_layer_metadata = \ - self._build_encoder_only_attn_metadata( - scheduler_output) - - # Add encoder attention metadata for all encoder layers - attention_layers = get_layers_from_vllm_config( - self.vllm_config, Attention) - for layer_name, attn_module in attention_layers.items(): - if attn_module.attn_type == AttentionType.ENCODER_ONLY: - 
common_attn_metadata, encoder_attn_metadata =\ - per_layer_metadata[layer_name] - attn_metadata[layer_name] = encoder_attn_metadata + # Used in the below loop. + query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] + seq_lens_cpu = self.seq_lens.cpu[:num_reqs] + num_computed_tokens_cpu = ( + self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) + spec_decode_common_attn_metadata = None # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): - blk_table = self.input_batch.block_table[kv_cache_group_id] - blk_table_tensor = blk_table.get_device_tensor()[:num_reqs] - slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens] + if isinstance(kv_cache_group_spec.kv_cache_spec, + EncoderOnlyAttentionSpec): + # Encoder-only layers do not have KV cache, so we need to + # create a dummy block table and slot mapping for them. + blk_table_tensor = torch.zeros( + (num_reqs, 1), + dtype=torch.int32, + pin_memory=self.pin_memory, + device="cpu").to(self.device, non_blocking=True) + slot_mapping = torch.zeros((total_num_scheduled_tokens, ), + dtype=torch.int32, + pin_memory=self.pin_memory, + device="cpu").to(self.device, + non_blocking=True) + num_common_prefix_blocks = 0 + else: + blk_table = self.input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor()[:num_reqs] + slot_mapping = blk_table.slot_mapping[: + total_num_scheduled_tokens] - # Fill unused with -1. Needed for reshape_and_cache in full cuda - # graph mode. - blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1) + # Fill unused with -1. Needed for reshape_and_cache in full cuda + # graph mode. + blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1) + num_common_prefix_blocks = ( + scheduler_output. 
+ num_common_prefix_blocks[kv_cache_group_id]) common_attn_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], - seq_lens=self.seq_lens[:num_reqs], - seq_lens_cpu=self.seq_lens_cpu[:num_reqs], - num_computed_tokens_cpu=self.input_batch. - num_computed_tokens_cpu_tensor[:num_reqs], + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + num_computed_tokens_cpu=num_computed_tokens_cpu, num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, + max_seq_len=max_seq_len, block_table_tensor=blk_table_tensor, slot_mapping=slot_mapping, causal=True, @@ -892,8 +872,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.cascade_attn_enabled: common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, - scheduler_output. - num_common_prefix_blocks[kv_cache_group_id], + num_common_prefix_blocks, kv_cache_group_spec.kv_cache_spec, builder, ) @@ -1060,9 +1039,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): src_start = num_computed_tokens src_end = num_computed_tokens + prompt_part_len - self.mrope_positions_cpu[:, dst_start:dst_end] = \ - req.mrope_positions[:,src_start:src_end] - + self.mrope_positions.cpu[:, dst_start:dst_end] = ( + req.mrope_positions[:, src_start:src_end]) mrope_pos_ptr += prompt_part_len if completion_part_len > 0: @@ -1071,7 +1049,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dst_end = mrope_pos_ptr + completion_part_len MRotaryEmbedding.get_next_input_positions_tensor( - out=self.mrope_positions_np, + out=self.mrope_positions.np, out_offset=dst_start, mrope_position_delta=req.mrope_position_delta, context_len=num_computed_tokens + prompt_part_len, @@ -1135,7 +1113,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, 
KVConnectorModelRunnerMixin): # Compute the draft token ids. # draft_token_indices: [ 1, 2, 3, 105, 106, 208] - draft_token_ids = self.input_ids[logits_indices] + draft_token_ids = self.input_ids.gpu[logits_indices] draft_token_ids = draft_token_ids[target_logits_indices + 1] metadata = SpecDecodeMetadata( @@ -1152,17 +1130,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: return - # Batch the multi-modal inputs. mm_kwargs = list[MultiModalKwargsItem]() - req_ids_pos = list[tuple[str, int, PlaceholderRange]]() + # list of tuple (mm_hash, position_info) + mm_hashes_pos = list[tuple[str, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] for mm_input_id in encoder_input_ids: + mm_hash = req_state.mm_hashes[mm_input_id] mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) - req_ids_pos.append( - (req_id, mm_input_id, req_state.mm_positions[mm_input_id])) + mm_hashes_pos.append( + (mm_hash, req_state.mm_positions[mm_input_id])) # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -1195,15 +1174,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for output in curr_group_outputs: encoder_outputs.append(output) - # Cache the encoder outputs. 
- for (req_id, input_id, pos_info), output in zip( - req_ids_pos, - encoder_outputs, - ): - if req_id not in self.encoder_cache: - self.encoder_cache[req_id] = {} - - self.encoder_cache[req_id][input_id] = scatter_mm_placeholders( + # Cache the encoder outputs by mm_hash + for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): + self.encoder_cache[mm_hash] = scatter_mm_placeholders( output, is_embed=pos_info.is_embed, ) @@ -1221,6 +1194,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_computed_tokens = \ req_state.num_computed_tokens + shift_computed_tokens mm_positions = req_state.mm_positions + mm_hashes = req_state.mm_hashes for i, pos_info in enumerate(mm_positions): start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -1240,11 +1214,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): start_idx = max(num_computed_tokens - start_pos, 0) end_idx = min( num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens) + num_encoder_tokens, + ) assert start_idx < end_idx - assert req_id in self.encoder_cache - assert i in self.encoder_cache[req_id] - encoder_output = self.encoder_cache[req_id][i] + + mm_hash = mm_hashes[i] + encoder_output = self.encoder_cache.get(mm_hash, None) + assert encoder_output is not None,\ + f"Encoder cache miss for {mm_hash}." if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] @@ -1337,9 +1314,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): out_indices = [] # Reorder the bitmask to match the order of the requests in the batch. 
- sorted_bitmask = np.zeros_like(grammar_bitmask, - shape=(logits.shape[0], - grammar_bitmask.shape[1])) + sorted_bitmask = np.full(shape=(logits.shape[0], + grammar_bitmask.shape[1]), + fill_value=-1, + dtype=grammar_bitmask.dtype) cumulative_index = 0 seq = sorted(scheduler_output.structured_output_request_ids.items(), key=lambda x: x[1]) @@ -1354,10 +1332,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): cumulative_index += 1 + num_spec_tokens grammar_bitmask = sorted_bitmask - # If the grammar bitmask and the logits have the same shape + # If the length of out indices and the logits have the same shape # we don't need to pass indices to the kernel, # since the bitmask is already aligned with the logits. - skip_out_indices = grammar_bitmask.shape[0] == logits.shape[0] + skip_out_indices = len(out_indices) == logits.shape[0] # Serialization of np.ndarray is much more efficient than a tensor, # so we receive it in that format. @@ -1422,7 +1400,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): model, is_dummy, is_profile, - log_stats=self.parallel_config.eplb_log_balancedness, + log_stats=self.parallel_config.eplb_config.log_balancedness, ) def get_dp_padding(self, @@ -1462,31 +1440,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): "Either all or none of the requests in" \ " a batch must be pooling request" - extracted_hidden_states = list( - torch.split(hidden_states[:num_scheduled_tokens], - num_scheduled_tokens_np.tolist())) - + hidden_states = hidden_states[:num_scheduled_tokens] pooling_metadata = self.input_batch.pooling_metadata + pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(), + device=hidden_states.device) + seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs] + # Pooling models D2H & synchronize occurs in pooler.py:build_output raw_pooler_output = self.model.pooler( - hidden_states=extracted_hidden_states, - 
pooling_metadata=pooling_metadata) + hidden_states=hidden_states, pooling_metadata=pooling_metadata) pooler_output: list[Optional[torch.Tensor]] = [] - seq_lens = self.seq_lens[:self.input_batch.num_reqs] for raw_output, seq_len, prompt_len in zip( - raw_pooler_output, seq_lens, pooling_metadata.prompt_lens): + raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens): - if seq_len == prompt_len: - pooler_output.append(raw_output.data.cpu()) - else: - pooler_output.append(None) + output = raw_output.data if seq_len == prompt_len else None + pooler_output.append(output) return ModelRunnerOutput( req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=[], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=pooler_output, @@ -1511,7 +1485,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Prepare the decoder inputs. (attn_metadata, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, spec_decode_common_attn_metadata, - max_query_len) = (self._prepare_inputs(scheduler_output)) + max_query_len) = self._prepare_inputs(scheduler_output) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE @@ -1549,7 +1523,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. 
inputs_embeds_scheduled = self.model.get_input_embeddings( - input_ids=self.input_ids[:num_scheduled_tokens], + input_ids=self.input_ids.gpu[:num_scheduled_tokens], multimodal_embeddings=mm_embeds or None, ) @@ -1568,13 +1542,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # While it is possible to use embeddings as input just like the # multimodal models, it is not desirable for performance since # then the embedding layer is not included in the CUDA graph. - input_ids = self.input_ids[:num_input_tokens] + input_ids = self.input_ids.gpu[:num_input_tokens] inputs_embeds = None model_kwargs = self._init_model_kwargs(num_input_tokens) if self.uses_mrope: - positions = self.mrope_positions[:, :num_input_tokens] + positions = self.mrope_positions.gpu[:, :num_input_tokens] else: - positions = self.positions[:num_input_tokens] + positions = self.positions.gpu[:num_input_tokens] if get_pp_group().is_first_rank: intermediate_tensors = None @@ -1600,6 +1574,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): batch_descriptor=batch_descriptor, ), self.maybe_get_kv_connector_output( scheduler_output) as kv_connector_output: + model_output = self.model( input_ids=input_ids, positions=positions, @@ -1714,7 +1689,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Compute prompt logprobs if needed. prompt_logprobs_dict = self._get_prompt_logprobs_dict( hidden_states[:num_scheduled_tokens], - scheduler_output, + scheduler_output.num_scheduled_tokens, ) # Get the valid generated tokens. @@ -1722,7 +1697,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): max_gen_len = sampled_token_ids.shape[-1] if max_gen_len == 1: # No spec decode tokens. - valid_sampled_token_ids = sampled_token_ids.tolist() + valid_sampled_token_ids = self._to_list(sampled_token_ids) else: # Includes spec decode tokens. 
valid_sampled_token_ids = self.rejection_sampler.parse_output( @@ -1738,6 +1713,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # NOTE(woosuk): As an exception, when using PP, the scheduler sends # the sampled tokens back, because there's no direct communication # between the first-stage worker and the last-stage worker. + req_ids = self.input_batch.req_ids for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): if not sampled_ids: continue @@ -1753,16 +1729,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): start_idx:end_idx] = sampled_ids self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx - req_id = self.input_batch.req_ids[req_idx] + req_id = req_ids[req_idx] req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) - if not self.speculative_config: - # Speculative decoding is not enabled. - spec_token_ids = None - else: + if self.speculative_config: assert spec_decode_common_attn_metadata is not None - spec_token_ids = self.propose_draft_token_ids( + self._draft_token_ids = self.propose_draft_token_ids( scheduler_output, valid_sampled_token_ids, sampling_metadata, @@ -1779,7 +1752,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=spec_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, pooler_output=[], @@ -1787,6 +1759,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_nans_in_logits=num_nans_in_logits, ) + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + if self._draft_token_ids is None: + return None + req_ids = self.input_batch.req_ids + if isinstance(self._draft_token_ids, torch.Tensor): + draft_token_ids = self._draft_token_ids.tolist() + else: + draft_token_ids = 
self._draft_token_ids + self._draft_token_ids = None + return DraftTokenIds(req_ids, draft_token_ids) + def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", @@ -1797,11 +1780,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): aux_hidden_states: Optional[torch.Tensor], spec_decode_metadata: Optional[SpecDecodeMetadata], common_attn_metadata: CommonAttentionMetadata, - ) -> list[list[int]]: + ) -> Union[list[list[int]], torch.Tensor]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": assert isinstance(self.drafter, NgramProposer) - spec_token_ids = self.propose_ngram_draft_token_ids( + draft_token_ids = self.propose_ngram_draft_token_ids( sampled_token_ids) elif self.speculative_config.method == "medusa": assert isinstance(self.drafter, MedusaProposer) @@ -1819,13 +1802,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): indices = torch.tensor(indices, device=self.device) hidden_states = sample_hidden_states[indices] - spec_token_ids = self.drafter.propose( + draft_token_ids = self.drafter.propose( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, ) elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. + req_ids = self.input_batch.req_ids next_token_ids: list[int] = [] for i, token_ids in enumerate(sampled_token_ids): if token_ids: @@ -1834,7 +1818,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): else: # Partial prefill (rare case). # Get the next token id from the request state. 
- req_id = self.input_batch.req_ids[i] + req_id = req_ids[i] req_state = self.requests[req_id] seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) @@ -1846,9 +1830,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if spec_decode_metadata is None: # input_ids can be None for multimodal models. - target_token_ids = self.input_ids[:num_scheduled_tokens] + target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] # TODO(woosuk): Support M-RoPE. - target_positions = self.positions[:num_scheduled_tokens] + target_positions = self.positions.gpu[:num_scheduled_tokens] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[:num_scheduled_tokens] for h in aux_hidden_states], @@ -1868,9 +1852,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.drafter.prepare_inputs( common_attn_metadata, num_rejected_tokens_cpu) - target_token_ids = self.input_ids[token_indices] + target_token_ids = self.input_ids.gpu[token_indices] # TODO(woosuk): Support M-RoPE. - target_positions = self.positions[token_indices] + target_positions = self.positions.gpu[token_indices] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[token_indices] for h in aux_hidden_states], dim=-1) @@ -1890,14 +1874,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): common_attn_metadata=common_attn_metadata, mm_embeds=mm_embeds, ) - spec_token_ids = draft_token_ids.tolist() - return spec_token_ids + return draft_token_ids def propose_ngram_draft_token_ids( self, sampled_token_ids: list[list[int]], ) -> list[list[int]]: # TODO(woosuk): Optimize. 
+ req_ids = self.input_batch.req_ids draft_token_ids: list[list[int]] = [] for i, sampled_ids in enumerate(sampled_token_ids): num_sampled_ids = len(sampled_ids) @@ -1908,7 +1892,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Skip requests that require sampling parameters that are not # supported with speculative decoding. - req_id = self.input_batch.req_ids[i] + req_id = req_ids[i] if req_id in self.input_batch.spec_decode_unsupported_reqs: draft_token_ids.append([]) continue @@ -1956,7 +1940,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): global_expert_load, old_global_expert_indices = ( EplbState.recv_state()) num_logical_experts = global_expert_load.shape[1] - self.parallel_config.num_redundant_experts = ( + self.parallel_config.eplb_config.num_redundant_experts = ( num_local_physical_experts * new_ep_size - num_logical_experts) assert old_global_expert_indices.shape[ 1] % num_local_physical_experts == 0 @@ -2056,7 +2040,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _get_prompt_logprobs_dict( self, hidden_states: torch.Tensor, - scheduler_output: "SchedulerOutput", + num_scheduled_tokens: dict[str, int], ) -> dict[str, Optional[LogprobsTensors]]: num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs if not num_prompt_logprobs_dict: @@ -2069,8 +2053,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # maintainable loop over optimal performance. completed_prefill_reqs = [] for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): - - num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_tokens = num_scheduled_tokens[req_id] # Get metadata for this request. request = self.requests[req_id] @@ -2113,7 +2096,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # If this is a partial request (i.e. chunked prefill), # then there is prompt logprob generated for each index. 
req_idx = self.input_batch.req_id_to_index[req_id] - offset = self.query_start_loc_np[req_idx].item() + offset = self.query_start_loc.np[req_idx].item() prompt_hidden_states = hidden_states[offset:offset + num_logits] logits = self.model.compute_logits(prompt_hidden_states, None) @@ -2186,12 +2169,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): @functools.cache def rand_input_ids() -> torch.Tensor: return torch.randint_like( - self.input_ids, + self.input_ids.gpu, low=0, high=self.model_config.get_vocab_size(), dtype=input_ids.dtype) - logger.debug("Randomizing dummy data for DP Rank") + logger.debug_once("Randomizing dummy data for DP Rank") input_ids.copy_(rand_input_ids()[:input_ids.size(0)], non_blocking=True) yield @@ -2203,19 +2186,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): max_items_per_batch: int, ) -> BatchedTensorInputs: """Dummy data for profiling and precompiling multimodal models.""" + assert self.mm_budget is not None + dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, seq_len=self.max_num_tokens, mm_counts={modality: 1}, + cache=self.mm_budget.cache, ) dummy_mm_data = dummy_decoder_data.multi_modal_data # Result in the maximum GPU consumption of the model - dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) + dummy_mm_item = dummy_mm_data[modality][0] + dummy_mm_items = [dummy_mm_item] * max_items_per_batch return next(mm_kwargs_group for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - [dummy_mm_item] * max_items_per_batch, + dummy_mm_items, device=self.device, pin_memory=self.pin_memory, )) @@ -2241,7 +2228,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): - CUDAGraphMode.PIECEWISE: Piecewise cudagraph. - CUDAGraphMode.FULL: Full cudagraph, attention metadata is needed. - force_attention: If True, always create attention metadata. 
Used to + force_attention: If True, always create attention metadata. Used to warm up attention backend when mode is NONE. uniform_decode: If True, the batch is a uniform decode batch. skip_eplb: If True, skip EPLB state update. @@ -2298,29 +2285,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # If force_attention is True, we always capture attention. Otherwise, # it only happens for cudagraph_runtime_mode=FULL. - if force_attention or cudagraph_runtime_mode == \ - CUDAGraphMode.FULL: + if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: attn_metadata = {} # Make sure max_model_len is used at the graph capture time. - self.seq_lens_np[:num_reqs] = self.max_model_len - self.seq_lens_np[num_reqs:] = 0 - self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], - non_blocking=True) + self.seq_lens.np[:num_reqs] = self.max_model_len + self.seq_lens.np[num_reqs:] = 0 + self.seq_lens.copy_to_gpu() for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): common_attn_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + + query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + 1], - seq_lens=self.seq_lens[:num_reqs], - seq_lens_cpu=self.seq_lens_cpu[:num_reqs], + seq_lens=self.seq_lens.gpu[:num_reqs], + seq_lens_cpu=self.seq_lens.cpu[:num_reqs], num_computed_tokens_cpu=self.input_batch. num_computed_tokens_cpu_tensor[:num_reqs], num_reqs=num_reqs, num_actual_tokens=num_tokens, max_query_len=max_query_len, + max_seq_len=self.max_model_len, block_table_tensor=self.input_batch.block_table[ kv_cache_group_id].get_device_tensor()[:num_reqs], slot_mapping=self.input_batch. 
@@ -2343,14 +2329,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): **self._dummy_mm_kwargs(num_reqs), } else: - input_ids = self.input_ids[:num_tokens] + input_ids = self.input_ids.gpu[:num_tokens] inputs_embeds = None model_kwargs = self._init_model_kwargs(num_tokens) if self.uses_mrope: - positions = self.mrope_positions[:, :num_tokens] + positions = self.mrope_positions.gpu[:, :num_tokens] else: - positions = self.positions[:num_tokens] + positions = self.positions.gpu[:num_tokens] if get_pp_group().is_first_rank: intermediate_tensors = None @@ -2446,7 +2432,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): output_token_ids=[[] for _ in range(num_reqs)], allowed_token_ids_mask=None, bad_words_token_ids={}, - logitsprocs=LogitsProcessorManager(), + logitsprocs=LogitsProcessors(), ) try: sampler_output = self.sampler(logits=logits, @@ -2503,13 +2489,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs - hidden_states_list = list( - torch.split(hidden_states, num_scheduled_tokens_list)) req_num_tokens = num_tokens // num_reqs dummy_prompt_lens = torch.tensor( - [h.shape[0] for h in hidden_states_list], - device=self.device, + num_scheduled_tokens_list, + device="cpu", ) dummy_token_ids = torch.zeros((num_reqs, req_num_tokens), dtype=torch.int32, @@ -2526,8 +2510,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): pooling_params=[dummy_pooling_params] * num_reqs, ) + dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, + device=hidden_states.device) + try: - return model.pooler(hidden_states=hidden_states_list, + return model.pooler(hidden_states=hidden_states, pooling_metadata=dummy_metadata) except RuntimeError as e: if 'out of memory' in str(e): @@ -2571,14 +2558,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # 
NOTE: Currently model is profiled with a single non-text # modality with the max possible input tokens even when # it supports multiple. - ( - dummy_modality, - max_tokens, - ) = mm_budget.get_modality_with_max_tokens() - ( - max_mm_items_per_prompt, - max_mm_items_per_batch, - ) = mm_budget.get_max_items(dummy_modality, max_tokens) + dummy_modality = mm_budget.get_modality_with_max_tokens() + max_mm_items_per_batch = mm_budget \ + .max_items_per_batch_by_modality[dummy_modality] logger.info( "Encoder cache will be initialized with a budget of " @@ -2741,11 +2723,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): """ assert len(self.attn_groups) == 0, \ "Attention backends are already initialized" - attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) def get_attn_backends_for_layers( layer_names: list[str] ) -> dict[type[AttentionBackend], list[str]]: + layers = get_layers_from_vllm_config(self.vllm_config, + AttentionLayerBase, + layer_names) attn_backends = {} attn_backend_layers = defaultdict(list) # Dedupe based on full class name; this is a bit safer than using @@ -2754,7 +2738,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # they are cached correctly, there will be different objects per # layer. 
for layer_name in layer_names: - attn_backend = attn_layers[layer_name].get_attn_backend() + attn_backend = layers[layer_name].get_attn_backend() key = attn_backend.full_cls_name() attn_backends[key] = attn_backend attn_backend_layers[key].append(layer_name) @@ -2783,69 +2767,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for kv_cache_group_spec in kv_cache_config.kv_cache_groups: kv_cache_spec = kv_cache_group_spec.kv_cache_spec - if isinstance(kv_cache_spec, AttentionSpec): - attn_backends = get_attn_backends_for_layers( - kv_cache_group_spec.layer_names) - # TODO(lucas): move `get_mamba_attn_backend` into the mamba - # layers like above - elif isinstance(kv_cache_spec, MambaSpec): - attn_backends = { - get_mamba_attn_backend(kv_cache_spec.mamba_type): - kv_cache_group_spec.layer_names - } - else: - raise ValueError( - f"Unknown KV cache spec type: {type(kv_cache_spec)}") - + attn_backends = get_attn_backends_for_layers( + kv_cache_group_spec.layer_names) self.attn_groups.append( create_attn_groups(attn_backends, kv_cache_spec)) # Calculate reorder batch threshold (if neeeded) self.calculate_reorder_batch_threshold() - if len(self.attn_groups) > 0: - return - - # Check if model is encoder-only - block_size = self.vllm_config.cache_config.block_size - use_mla = self.vllm_config.model_config.use_mla - attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list) - for layer_name, attn_module in attn_layers.items(): - - if attn_module.attn_type == AttentionType.ENCODER_ONLY: - if attn_module.sliding_window is None: - attn_spec: AttentionSpec = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - use_mla=use_mla) - else: - attn_spec = SlidingWindowSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window, - 
use_mla=use_mla) - attn_specs[attn_spec].append(layer_name) - - else: - raise ValueError("Expected only encoder-only layers") - - if len(attn_specs) > 0: - total_layers = 0 - for attn_spec, layer_names in attn_specs.items(): - - attn_backends = get_attn_backends_for_layers(layer_names) - total_layers += len(layer_names) - - self.attn_groups.append( - create_attn_groups(attn_backends, attn_spec)) - assert total_layers == len(attn_layers), \ - "All or none of the layers are expected to be encoder-only" - self.is_encoder_only_model = True - def initialize_cudagraph_capture(self) -> None: min_cg_support = AttentionCGSupport.ALWAYS min_cg_builder_name = None @@ -2967,6 +2896,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): vocab_size=self.model_config.get_vocab_size(), block_sizes=block_sizes, is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=self.input_batch.logitsprocs, + is_pooling_model=self.is_pooling_model, ) def _allocate_kv_cache_tensors( @@ -2991,7 +2922,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): layer_names = set() for group in kv_cache_config.kv_cache_groups: - layer_names.update(group.layer_names) + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + layer_names.add(layer_name) assert layer_names == set(kv_cache_raw_tensors.keys( )), "Some layers are not correctly initialized" return kv_cache_raw_tensors @@ -3029,6 +2963,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): attn_backend = group.backend for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 num_blocks = (raw_tensor.numel() // @@ -3090,40 +3026,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, 
KVConnectorModelRunnerMixin): raise NotImplementedError if has_attn and has_mamba: - self._verify_hybrid_attention_mamba_layout(kv_cache_config, - kv_cache_raw_tensors) + self._update_hybrid_attention_mamba_layout(kv_caches) return kv_caches - def _verify_hybrid_attention_mamba_layout( - self, kv_cache_config: KVCacheConfig, - kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None: + def _update_hybrid_attention_mamba_layout( + self, kv_caches: dict[str, torch.Tensor]) -> None: """ - Verify that the KV cache memory layout is compatible for - models with both attention and mamba KV cache groups. + Update the layout of attention layers from (2, num_blocks, ...) to + (num_blocks, 2, ...). Args: - kv_cache_config: The KV cache config - kv_cache_raw_tensors: The KV cache buffer of each layer. + kv_caches: The KV cache buffer of each layer. """ for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): for layer_name in group.layer_names: - raw_tensor = kv_cache_raw_tensors[layer_name] - num_blocks = (raw_tensor.numel() // - kv_cache_spec.page_size_bytes) - if isinstance(kv_cache_spec, AttentionSpec): - - kv_cache_shape = group.backend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) - if kv_cache_shape[0] != num_blocks or kv_cache_shape[ - 1] != 2: - raise ValueError( - "Hybrid models in V1 require an attention " - "backend with kv_cache_shape=" - "(num_blocks, 2, ...). Please try setting " - "VLLM_ATTENTION_BACKEND=FLASHINFER") + kv_cache = kv_caches[layer_name] + if (isinstance(kv_cache_spec, AttentionSpec) + and kv_cache.shape[0] == 2): + assert kv_cache.shape[1] != 2, \ + "Fail to determine whether the layout is " \ + "(2, num_blocks, ...) or (num_blocks, 2, ...) 
for " \ + f"a tensor of shape {kv_cache.shape}" + hidden_size = kv_cache.shape[2:].numel() + kv_cache.as_strided_(size=kv_cache.shape, + stride=(hidden_size, 2 * hidden_size, + *kv_cache.stride()[2:])) def initialize_kv_cache_tensors( self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: @@ -3150,6 +3079,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_cache_config.kv_cache_groups, kv_caches, self.attn_groups, + self.runner_only_attn_layers, ) attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) @@ -3174,8 +3104,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ + kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config self.may_reinitialize_input_batch(kv_cache_config) + self.may_add_encoder_only_layers_to_kv_cache_config() self.initialize_attn_backend(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) @@ -3188,6 +3120,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) + def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: + """ + Add encoder-only layers to the KV cache config. 
+ """ + block_size = self.vllm_config.cache_config.block_size + use_mla = self.vllm_config.model_config.use_mla + encoder_only_attn_specs: dict[AttentionSpec, + list[str]] = defaultdict(list) + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for layer_name, attn_module in attn_layers.items(): + if attn_module.attn_type == AttentionType.ENCODER_ONLY: + attn_spec = EncoderOnlyAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla) + encoder_only_attn_specs[attn_spec].append(layer_name) + self.runner_only_attn_layers.add(layer_name) + if len(encoder_only_attn_specs) > 0: + assert len( + encoder_only_attn_specs + ) == 1, "Only support one encoder-only attention spec now" + spec, layer_names = encoder_only_attn_specs.popitem() + self.kv_cache_config.kv_cache_groups.append( + KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)) + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each @@ -3277,64 +3236,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return kv_cache_spec - def _build_encoder_only_attn_metadata( - self, scheduler_output: "SchedulerOutput") -> \ - dict[str, tuple[CommonAttentionMetadata, Any]]: - """Prepare encoder attention metadata for encoder-only models. - - Args: - scheduler_output: Scheduler output - - Returns: - dict[str, Any]: Encoder attention metadata - """ - num_reqs = self.input_batch.num_reqs - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - - # Get the number of scheduled tokens for each request. 
- req_ids = self.input_batch.req_ids - tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - max_num_scheduled_tokens = max(tokens) - - dummy_block_table = torch.zeros((num_reqs, 1), - dtype=torch.int32, - device=self.device) - dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ), - dtype=torch.int32, - device=self.device) - - group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]() - - for attn_group_list in self.attn_groups: - - assert len(attn_group_list) == 1 - attn_group = attn_group_list[0] - - # Use the first attention metadata builder - # to create encoder attention metadata - builder = attn_group.metadata_builder - - common_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], - seq_lens=self.seq_lens[:num_reqs], - seq_lens_cpu=self.seq_lens_cpu[:num_reqs], - num_computed_tokens_cpu=self.input_batch. - num_computed_tokens_cpu_tensor[:num_reqs], - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - block_table_tensor=dummy_block_table, - slot_mapping=dummy_slot_mapping, - causal=False, - ) - - metadata = builder.build( - common_prefix_len=0, # No cascade for encoder - common_attn_metadata=common_metadata, - ) - - for layer_name in attn_group.layer_names: - group_metadata[layer_name] = (common_metadata, metadata) - - return group_metadata + def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: + # This is a short term mitigation for issue mentioned in + # https://github.com/vllm-project/vllm/issues/22754. + # `tolist` would trigger a cuda wise stream sync, which + # would block other copy ops from other cuda streams. + # A cuda event sync would avoid such a situation. Since + # this is in the critical path of every single model + # forward loop, this has caused perf issue for a disagg + # setup. 
+ pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]] + pinned.copy_(sampled_token_ids, non_blocking=True) + self.transfer_event.record() + self.transfer_event.synchronize() + return pinned.tolist() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 04de8d36680a4..c252193313344 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -28,7 +28,8 @@ from vllm.tasks import SupportedTask from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, + ModelRunnerOutput) from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.worker_base import WorkerBase @@ -166,7 +167,7 @@ class Worker(WorkerBase): self.device = torch.device(f"cuda:{self.local_rank}") current_platform.set_device(self.device) - _check_if_gpu_supports_dtype(self.model_config.dtype) + current_platform.check_if_supports_dtype(self.model_config.dtype) gc.collect() torch.cuda.empty_cache() @@ -215,8 +216,7 @@ class Worker(WorkerBase): self.model_runner.update_config(overrides) def reload_weights(self) -> None: - with self._maybe_get_memory_pool_context(tag="weights"): - self.model_runner.reload_weights() + self.model_runner.reload_weights() @torch.inference_mode() def determine_available_memory(self) -> int: @@ -291,7 +291,6 @@ class Worker(WorkerBase): allocator = CuMemAllocator.get_instance() context = allocator.use_memory_pool(tag="kv_cache") else: - from contextlib import nullcontext context = nullcontext() with context: self.model_runner.initialize_kv_cache(kv_cache_config) @@ -311,6 +310,10 @@ class Worker(WorkerBase): logger.info("Compile and warming up model for size %d", 
size) self.model_runner._dummy_run(size, skip_eplb=True) + # Warmup and tune the kernels used during model execution before + # cuda graph capture. + kernel_warmup(self) + if not self.model_config.enforce_eager: self.model_runner.capture_model() @@ -335,9 +338,6 @@ class Worker(WorkerBase): self.model_runner._dummy_sampler_run( hidden_states=last_hidden_states) - # Warmup kernels used during model execution - kernel_warmup(self) - # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. set_random_seed(self.model_config.seed) @@ -386,6 +386,9 @@ class Worker(WorkerBase): assert isinstance(output, ModelRunnerOutput) return output + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + return self.model_runner.take_draft_token_ids() + def profile(self, is_start: bool = True): if self.profiler is None: raise RuntimeError("Profiler is not enabled.") @@ -511,7 +514,7 @@ class Worker(WorkerBase): assert self.model_runner.eplb_state is not None new_physical_experts = \ self.model_runner.eplb_state.physical_to_logical_map.shape[1] - parallel_config.num_redundant_experts = ( + parallel_config.eplb_config.num_redundant_experts = ( new_physical_experts - self.model_runner.eplb_state.logical_replica_count.shape[1]) global_expert_load = None @@ -527,7 +530,7 @@ class Worker(WorkerBase): assert self.model_runner.eplb_state is not None global_expert_load = self.model_runner.eplb_state.rearrange( self.model_runner.model, execute_shuffle=False) - parallel_config.num_redundant_experts = ( + parallel_config.eplb_config.num_redundant_experts = ( new_physical_experts - global_expert_load.shape[1]) prepare_communication_buffer_for_model(self.model_runner.model) self.model_runner.model.update_physical_experts_metadata( @@ -610,23 +613,3 @@ def init_worker_distributed_environment( parallel_config.pipeline_parallel_size) ensure_kv_transfer_initialized(vllm_config) - - -def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): - 
# Check if the GPU supports the dtype. - if torch_dtype == torch.bfloat16: # noqa: SIM102 - if not current_platform.has_device_capability(80): - capability = current_platform.get_device_capability() - gpu_name = current_platform.get_device_name() - - if capability is None: - compute_str = "does not have a compute capability" - else: - version_str = capability.as_version_str() - compute_str = f"has compute capability {version_str}" - - raise ValueError( - "Bfloat16 is only supported on GPUs with compute capability " - f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " - "You can use float16 instead by explicitly setting the " - "`dtype` flag in CLI, for example: --dtype=half.") diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 2fbdee4724e35..84ed46989ea97 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -8,6 +8,7 @@ from contextlib import contextmanager from typing import Union import numpy as np +import torch import torch.nn as nn from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig @@ -31,7 +32,8 @@ class LoRAModelRunnerMixin: def load_lora_model(self, model: nn.Module, model_config: ModelConfig, scheduler_config: SchedulerConfig, - lora_config: LoRAConfig, device: str) -> nn.Module: + lora_config: LoRAConfig, + device: torch.device) -> nn.Module: if not supports_lora(model): raise ValueError( diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index f7e68edba3a13..d364236604274 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -208,8 +208,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Lazy initialization self.model: nn.Module # Set after load_model self.kv_caches: list[torch.Tensor] = [] - # req_id -> (input_id -> encoder_output) - self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} + # mm_hash -> encoder_output + 
self.encoder_cache: dict[str, torch.Tensor] = {} # Request states. self.requests: dict[str, CachedRequestState] = {} @@ -292,8 +292,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.model_config, self.scheduler_config, self.mm_registry, - max_model_len=self.max_model_len, - max_num_reqs=self.max_num_reqs, ) if self.supports_mm_inputs else None) if not self.use_spmd: @@ -344,7 +342,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) - self.encoder_cache.pop(req_id, None) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and @@ -359,12 +356,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): removed_req_indices.append(req_index) # Free the cached encoder outputs. - for req_id, input_id in scheduler_output.free_encoder_input_ids: - encoder_outputs = self.encoder_cache.get(req_id) - if encoder_outputs is not None: - encoder_outputs.pop(input_id, None) - if not encoder_outputs: - self.encoder_cache.pop(req_id, None) + for mm_hash in scheduler_output.free_encoder_mm_hashes: + self.encoder_cache.pop(mm_hash, None) # Remove the unscheduled requests from the persistent batch. # NOTE(woosuk): The unscheduled requests are either preempted requests @@ -396,6 +389,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): prompt_token_ids=new_req_data.prompt_token_ids, mm_kwargs=new_req_data.mm_kwargs, mm_positions=new_req_data.mm_positions, + mm_hashes=new_req_data.mm_hashes, sampling_params=sampling_params, pooling_params=None, generator=None, @@ -418,11 +412,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Update the cached states. 
req_state.num_computed_tokens = num_computed_tokens if not resumed_from_preemption: - # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): - block_ids.extend(new_ids) + if new_block_ids is not None: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip(req_state.block_ids, + new_block_ids): + block_ids.extend(new_ids) else: + assert new_block_ids is not None # The request is resumed from preemption. # Replace the existing block IDs with the new ones. req_state.block_ids = new_block_ids @@ -438,7 +434,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - self.input_batch.block_table.append_row(new_block_ids, req_index) + if new_block_ids is not None: + self.input_batch.block_table.append_row( + new_block_ids, req_index) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. @@ -843,14 +841,16 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Batch the multi-modal inputs. 
mm_kwargs = list[MultiModalKwargsItem]() - req_ids_pos = list[tuple[str, int, PlaceholderRange]]() + # List of tuple (mm_hash, pos_info) + mm_hashes_pos = list[tuple[str, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] for mm_input_id in encoder_input_ids: + mm_hash = req_state.mm_hashes[mm_input_id] mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) - req_ids_pos.append( - (req_id, mm_input_id, req_state.mm_positions[mm_input_id])) + mm_hashes_pos.append( + (mm_hash, req_state.mm_positions[mm_input_id])) # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -893,15 +893,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # NOTE (NickLucche) here we diverge from logic in other runners, as we # assume to only have whole mm items to process. Hence we avoid the # intrinsic dynamism that `scatter_mm_placeholders` introduces. - for (req_id, input_id, pos_info), output in zip( - req_ids_pos, + for (mm_hash, pos_info), output in zip( + mm_hashes_pos, encoder_outputs, ): if req_id not in self.encoder_cache: self.encoder_cache[req_id] = {} assert pos_info.is_embed is None, "Expected all positions to be"\ " contiguous and embeddings." - self.encoder_cache[req_id][input_id] = output + self.encoder_cache[mm_hash] = output def _gather_mm_embeddings( self, @@ -914,6 +914,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_state = self.requests[req_id] num_computed_tokens = req_state.num_computed_tokens mm_positions = req_state.mm_positions + mm_hashes = req_state.mm_hashes # TODO unroll loop and assume/enforce --disable_chunked_mm_input # NOTE (NickLucche) here we diverge from logic in other runners, as # we assume to only have whole mm items to process. 
Hence we avoid @@ -934,11 +935,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # in the decoder's KV cache. continue - assert req_id in self.encoder_cache - assert i in self.encoder_cache[req_id] + mm_hash = mm_hashes[i] + encoder_output = self.encoder_cache.get(mm_hash, None) + assert encoder_output is not None,\ + f"Encoder cache miss for {mm_hash}." assert pos_info.is_embed is None, "Expected all positions to"\ " be contiguous and embeddings." - encoder_output = self.encoder_cache[req_id][i] + encoder_output = self.encoder_cache[mm_hash] mm_embeds.append(encoder_output) return mm_embeds @@ -1145,7 +1148,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=None, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, pooler_output=[], @@ -1542,14 +1544,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # NOTE: Currently model is profiled with a single non-text # modality with the max possible input tokens even when # it supports multiple. 
- ( - dummy_modality, - max_tokens, - ) = mm_budget.get_modality_with_max_tokens() - ( - max_mm_items_per_prompt, - max_mm_items_per_batch, - ) = mm_budget.get_max_items(dummy_modality, max_tokens) + dummy_modality = mm_budget.get_modality_with_max_tokens() + max_mm_items_per_batch = mm_budget \ + .max_items_per_batch_by_modality[dummy_modality] logger.info( "Encoder cache will be initialized with a budget of " @@ -1816,19 +1813,23 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): max_items_per_batch: int, ) -> BatchedTensorInputs: """Dummy data for profiling and precompiling multimodal models.""" + assert self.mm_budget is not None + dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, seq_len=self.max_num_tokens, mm_counts={modality: 1}, + cache=self.mm_budget.cache, ) dummy_mm_data = dummy_decoder_data.multi_modal_data # Result in the maximum GPU consumption of the model - dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) + dummy_mm_item = dummy_mm_data[modality][0] + dummy_mm_items = [dummy_mm_item] * max_items_per_batch return next(grouped_mm_kwargs for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality( - [dummy_mm_item] * max_items_per_batch, + dummy_mm_items, device=self.device, pin_memory=self.pin_memory, )) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 72e0e4230a017..9adf8a14213f3 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -1,15 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A TPU worker class.""" + import os from typing import Any, Optional import torch import torch.distributed import torch.nn as nn -import torch_xla.core.xla_model as xm -import torch_xla.debug.profiler as xp -import torch_xla.runtime as xr import vllm.envs as envs from vllm.config import VllmConfig @@ -21,19 +19,27 @@ from vllm.logger import init_logger 
from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.platforms import current_platform +from vllm.platforms.tpu import USE_TPU_COMMONS from vllm.tasks import SupportedTask from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv -from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import report_usage_stats -from vllm.v1.worker.tpu_model_runner import TPUModelRunner from vllm.v1.worker.utils import bind_kv_cache logger = init_logger(__name__) +if not USE_TPU_COMMONS: + logger.info("tpu_commons not found, using vLLM's TPUWorker.") + import torch_xla.core.xla_model as xm + import torch_xla.debug.profiler as xp + import torch_xla.runtime as xr + + from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT + from vllm.v1.worker.tpu_model_runner import TPUModelRunner + class TPUWorker: @@ -325,9 +331,7 @@ class TPUWorker: ensure_kv_transfer_initialized(vllm_config) -try: +if USE_TPU_COMMONS: from tpu_commons.worker import TPUWorker as TPUCommonsWorker + TPUWorker = TPUCommonsWorker # type: ignore -except ImportError: - logger.info("tpu_commons not found, using vLLM's TPUWorker.") - pass diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index e7079235d6510..f407534687662 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -10,9 +10,10 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.config import ModelConfig, SchedulerConfig from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.utils import extract_layer_index +from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.registry import MultiModalRegistry from vllm.v1.attention.backends.utils import 
AttentionMetadataBuilder -from vllm.v1.core.encoder_cache_manager import compute_encoder_budget +from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.kv_cache_interface import KVCacheGroupSpec if TYPE_CHECKING: @@ -27,35 +28,36 @@ class MultiModalBudget: model_config: ModelConfig, scheduler_config: SchedulerConfig, mm_registry: MultiModalRegistry, - *, - max_model_len: int, - max_num_reqs: int, ) -> None: super().__init__() self.model_config = model_config self.scheduler_config = scheduler_config self.mm_registry = mm_registry + self.cache = cache = processor_only_cache_from_config( + model_config, mm_registry) - encoder_compute_budget, encoder_cache_size = compute_encoder_budget( - model_config=model_config, - scheduler_config=scheduler_config, - mm_registry=mm_registry, + self.max_model_len = model_config.max_model_len + self.max_num_reqs = scheduler_config.max_num_seqs + + self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, + cache=cache) + + max_tokens_by_modality = mm_registry \ + .get_max_tokens_per_item_by_nonzero_modality(model_config, + cache=cache) + + encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget( + scheduler_config, + max_tokens_by_modality, ) - self.max_num_encoder_input_tokens = encoder_compute_budget + self.encoder_compute_budget = encoder_compute_budget self.encoder_cache_size = encoder_cache_size - self.max_model_len = max_model_len - self.max_num_reqs = max_num_reqs - - self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config) max_items_per_prompt_by_modality = dict[str, int]() max_items_per_batch_by_modality = dict[str, int]() - max_tokens_by_modality = mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(model_config) - for modality, max_tokens in max_tokens_by_modality.items(): ( max_items_per_prompt, @@ -69,15 +71,14 @@ class MultiModalBudget: self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality self.max_items_per_batch_by_modality = 
max_items_per_batch_by_modality - def get_modality_with_max_tokens(self) -> tuple[str, int]: + def get_modality_with_max_tokens(self) -> str: max_tokens_by_modality = self.max_tokens_by_modality - modality, max_tokens = max(max_tokens_by_modality.items(), - key=lambda item: item[1]) + modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1]) - return modality, max_tokens + return modality def get_encoder_budget(self) -> int: - return min(self.max_num_encoder_input_tokens, self.encoder_cache_size) + return min(self.encoder_compute_budget, self.encoder_cache_size) def get_max_items( self, @@ -208,6 +209,7 @@ def initialize_kv_cache_for_kv_sharing( kv_caches: dict[str, torch.Tensor], # Optional for now to avoid breaking TPU attn_groups: Optional[list[list[AttentionGroup]]] = None, + runner_only_attn_layers: Optional[set[str]] = None, ) -> None: """ Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches` @@ -225,26 +227,37 @@ def initialize_kv_cache_for_kv_sharing( Note that layers in shared_kv_cache_layers.keys() are not originally included as it only contains layers which have its own KV cache allocation. + attn_groups: Optional list of attention groups. Layers in the same KV + cache group may be placed in different attention groups if they + have different attention backends. Currently only provided by + GPU model runner. """ - # Record index of KV cache group for each layer that allocates a KV cache. 
- layer_to_kv_cache_group_idx: dict[str, int] = {} - for i, kv_cache_group in enumerate(kv_cache_groups): - for layer_name in kv_cache_group.layer_names: - layer_to_kv_cache_group_idx[layer_name] = i + # mapping from layer name to tuple of (kv_cache_group_idx, attn_group_idx) + layer_to_attn_group_idx: dict[str, tuple[int, int]] = {} + if attn_groups: + for kv_cache_group_idx, kv_attn_groups in enumerate(attn_groups): + for attn_group_idx, attn_group in enumerate(kv_attn_groups): + for layer_name in attn_group.layer_names: + layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx, + attn_group_idx) + else: + for kv_cache_group_idx, kv_cache_group in enumerate(kv_cache_groups): + for layer_name in kv_cache_group.layer_names: + # attn group idx default to 0 if not provided + layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx, 0) for layer_name, target_layer_name in shared_kv_cache_layers.items(): kv_caches[layer_name] = kv_caches[target_layer_name] - group_idx = layer_to_kv_cache_group_idx[target_layer_name] - kv_cache_groups[group_idx].layer_names.append(layer_name) + kv_cache_group_idx = layer_to_attn_group_idx[target_layer_name][0] + kv_cache_groups[kv_cache_group_idx].layer_names.append(layer_name) - if attn_groups is not None: - assert len(attn_groups[group_idx]) == 1, ( - "Only one attention group per KV cache group is supported " - "for KV-cache sharing for now.") - # TODO(lucas): I think in the future the layers that re-use a - # KV cache will be in a different attention group so we can - # remove this code from here. 
- attn_groups[group_idx][0].layer_names.append(layer_name) + if attn_groups: + attn_group_idx = layer_to_attn_group_idx[target_layer_name][1] + attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append( + layer_name) + + if runner_only_attn_layers is not None: + runner_only_attn_layers.add(layer_name) def bind_kv_cache( diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py index 59f8d0fcf5bd9..fb892211f19db 100644 --- a/vllm/v1/worker/xpu_model_runner.py +++ b/vllm/v1/worker/xpu_model_runner.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from contextlib import contextmanager from typing import TYPE_CHECKING import torch @@ -22,7 +23,8 @@ class XPUModelRunner(GPUModelRunner): vllm_config: VllmConfig, device: torch.device, ): - super().__init__(vllm_config, device) + with _torch_cuda_wrapper(): + super().__init__(vllm_config, device) # FIXME: To be verified. self.cascade_attn_enabled = False @@ -31,3 +33,21 @@ class XPUModelRunner(GPUModelRunner): def _sync_device(self) -> None: torch.xpu.synchronize() + + +@contextmanager +def _torch_cuda_wrapper(): + + class _EventPlaceholder: + + def __init__(self, *args, **kwargs) -> None: + self.record = lambda: None + self.synchronize = lambda: None + + try: + # replace cuda Event with xpu Event, this should work by default + torch.cuda.Event = torch.xpu.Event + yield + finally: + # if anything goes wrong, just patch it with a placeholder + torch.cuda.Event = _EventPlaceholder diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py index 134d839252653..17288cda8eccf 100644 --- a/vllm/v1/worker/xpu_worker.py +++ b/vllm/v1/worker/xpu_worker.py @@ -145,6 +145,7 @@ class XPUWorker(Worker): ): self.device = torch.device(f"xpu:{self.local_rank}") current_platform.set_device(self.device) + current_platform.check_if_supports_dtype(self.model_config.dtype) torch.xpu.empty_cache() self.init_gpu_memory = 
torch.xpu.get_device_properties( self.local_rank).total_memory diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index e49783ad9b244..3e1950798dbf6 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -149,9 +149,16 @@ class PoolingModelRunner( if not self.is_driver_worker: return [] + pooling_metadata = model_input.pooling_metadata + assert pooling_metadata is not None + + pooling_metadata.build_pooling_cursor( + num_scheduled_tokens=pooling_metadata.prompt_lens, + device=hidden_or_intermediate_states.device) + return [ self.model.pooler(hidden_states=hidden_or_intermediate_states, - pooling_metadata=model_input.pooling_metadata) + pooling_metadata=pooling_metadata) ] def make_model_input_from_broadcasted_tensor_dict( @@ -192,8 +199,9 @@ class PoolingModelRunner( pooling_params = seq_group_metadata.pooling_params assert pooling_params is not None - assert (task := pooling_params.task) is not None, ( - "You did not set `task` in the API") + + task = pooling_params.task + assert task is not None, "You did not set `task` in the API" model = cast(VllmModelForPooling, self.model) to_update = model.pooler.get_pooling_updates(task) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 9dfea947568d4..fc24d95b80f2c 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -3,6 +3,7 @@ """A GPU worker class.""" import gc import os +from contextlib import nullcontext from typing import Dict, List, Optional, Set, Tuple, Type, Union import torch @@ -77,7 +78,8 @@ class Worker(LocalOrDistributedWorkerBase): "eagle", "deepseek_mtp", "glm4_moe_mtp", - "mimo_mtp")) \ + "mimo_mtp", + "ernie_mtp")) \ else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner @@ -205,7 +207,6 @@ class Worker(LocalOrDistributedWorkerBase): "used for one instance per process.") context = allocator.use_memory_pool(tag="weights") else: - from contextlib import nullcontext 
context = nullcontext() with context: self.model_runner.load_model() @@ -329,7 +330,6 @@ class Worker(LocalOrDistributedWorkerBase): allocator = CuMemAllocator.get_instance() context = allocator.use_memory_pool(tag="kv_cache") else: - from contextlib import nullcontext context = nullcontext() with context: self._init_cache_engine() diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index f1c9a0ab001e8..a1fa7f2cf7a2e 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -544,7 +544,7 @@ class WorkerWrapperBase: Arguments are passed to the worker class constructor. """ kwargs = all_kwargs[self.rpc_rank] - self.vllm_config = kwargs.get("vllm_config", None) + self.vllm_config = kwargs.get("vllm_config") assert self.vllm_config is not None, ( "vllm_config is required to initialize the worker") enable_trace_function_call_for_thread(self.vllm_config)