diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py
index 7045d8810493e..bbed80ebe8476 100644
--- a/.buildkite/generate_index.py
+++ b/.buildkite/generate_index.py
@@ -8,7 +8,8 @@ template = """
Links for vLLM
- {wheel}
+ {x86_wheel}
+ {arm_wheel}
"""
@@ -21,7 +22,25 @@ filename = os.path.basename(args.wheel)
with open("index.html", "w") as f:
print(f"Generated index.html for {args.wheel}")
+ # sync the abi tag with .buildkite/scripts/upload-wheels.sh
+ if "x86_64" in filename:
+ x86_wheel = filename
+ arm_wheel = filename.replace("x86_64", "aarch64").replace(
+ "manylinux1", "manylinux2014"
+ )
+ elif "aarch64" in filename:
+ x86_wheel = filename.replace("aarch64", "x86_64").replace(
+ "manylinux2014", "manylinux1"
+ )
+ arm_wheel = filename
+ else:
+        raise ValueError(f"Unsupported wheel: {filename}")
# cloudfront requires escaping the '+' character
f.write(
- template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B"))
+ template.format(
+ x86_wheel=x86_wheel,
+ x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"),
+ arm_wheel=arm_wheel,
+ arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"),
+ )
)
diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
deleted file mode 100644
index 56ec933c9cc0e..0000000000000
--- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
+++ /dev/null
@@ -1,12 +0,0 @@
-# For vllm script, with -t option (tensor parallel size).
-# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
-model_name: "HandH1998/QQQ-Llama-3-8b-g128"
-tasks:
-- name: "gsm8k"
- metrics:
- - name: "exact_match,strict-match"
- value: 0.419
- - name: "exact_match,flexible-extract"
- value: 0.416
-limit: 1000
-num_fewshot: 5
diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt
index 27a1a9a82bd35..37eeac85c933b 100644
--- a/.buildkite/lm-eval-harness/configs/models-large.txt
+++ b/.buildkite/lm-eval-harness/configs/models-large.txt
@@ -3,4 +3,3 @@ Meta-Llama-3-70B-Instruct.yaml
Mixtral-8x7B-Instruct-v0.1.yaml
Qwen2-57B-A14-Instruct.yaml
DeepSeek-V2-Lite-Chat.yaml
-Meta-Llama-3-8B-QQQ.yaml
diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh
index a67fc89d54e60..897f84d1e360d 100644
--- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh
+++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh
@@ -2,7 +2,7 @@
# We can use this script to compute baseline accuracy on GSM for transformers.
#
# Make sure you have lm-eval-harness installed:
-# pip install lm-eval==0.4.4
+# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
usage() {
echo``
diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
index b98d42aa7b822..792f355c47a51 100644
--- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
+++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
@@ -3,7 +3,7 @@
# We use this for fp8, which HF does not support.
#
# Make sure you have lm-eval-harness installed:
-# pip install lm-eval==0.4.4
+# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
usage() {
echo``
diff --git a/.buildkite/nightly-benchmarks/scripts/compare-json-results.py b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py
index 12c4ba6aa69a6..50431d0cd4c5e 100644
--- a/.buildkite/nightly-benchmarks/scripts/compare-json-results.py
+++ b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py
@@ -3,44 +3,129 @@
import argparse
import json
import os
+from importlib import util
import pandas as pd
+plotly_found = util.find_spec("plotly.express") is not None
+
def compare_data_columns(
files, name_column, data_column, info_cols, drop_column, debug=False
):
- print("\ncompare_data_column: " + data_column)
+ """
+ Align concatenation by keys derived from info_cols instead of row order.
+ - Pick one canonical key list: subset of info_cols present in ALL files.
+ - For each file: set index to those keys, aggregate duplicates
+ - (mean for metric, first for names).
+ - Concat along axis=1 (indexes align), then reset_index so callers can
+ - group by columns.
+ - If --debug, add a _name column per file.
+ """
+ print("\ncompare_data_column:", data_column)
+
frames = []
raw_data_cols = []
compare_frames = []
+
+ # 1) choose a canonical key list from info_cols that exists in ALL files
+ cols_per_file = []
+ for f in files:
+ try:
+ df_tmp = pd.read_json(f, orient="records")
+ except Exception as err:
+ raise ValueError(f"Failed to read {f}") from err
+ cols_per_file.append(set(df_tmp.columns))
+
+ key_cols = [c for c in info_cols if all(c in cset for cset in cols_per_file)]
+ if not key_cols:
+ # soft fallback: use any info_cols present in the first file
+ key_cols = [c for c in info_cols if c in list(cols_per_file[0])]
+ if not key_cols:
+ raise ValueError(
+ "No common key columns found from info_cols across the input files."
+ )
+
+ # 2) build a single "meta" block (keys as columns) once, aligned by the key index
+ meta_added = False
+
for file in files:
- data_df = pd.read_json(file)
- serving_df = data_df.dropna(subset=[drop_column], ignore_index=True)
- # Show all info columns in the first couple columns
- if not frames:
- for col in info_cols:
- if col not in serving_df.columns:
- print(f"Skipping missing column: {col}")
- continue
- frames.append(serving_df[col])
- # only show test name under debug mode
- if debug is True:
- serving_df = serving_df.rename(columns={name_column: file + "_name"})
- frames.append(serving_df[file + "_name"])
+ df = pd.read_json(file, orient="records")
- file = "/".join(file.split("/")[:-1])
- serving_df = serving_df.rename(columns={data_column: file})
- frames.append(serving_df[file])
- raw_data_cols.append(file)
- compare_frames.append(serving_df[file])
+ # Keep rows that actually have the compared metric (same as original behavior)
+ if drop_column in df.columns:
+ df = df.dropna(subset=[drop_column], ignore_index=True)
+
+ # Stabilize numeric key columns (harmless if missing)
+ for c in (
+ "Input Len",
+ "Output Len",
+ "TP Size",
+ "PP Size",
+ "# of max concurrency.",
+ "qps",
+ ):
+ if c in df.columns:
+ df[c] = pd.to_numeric(df[c], errors="coerce")
+
+ # Ensure all key columns exist
+ for c in key_cols:
+ if c not in df.columns:
+ df[c] = pd.NA
+
+ # Set index = key_cols and aggregate duplicates → unique MultiIndex
+ df_idx = df.set_index(key_cols, drop=False)
+
+ # meta (key columns), unique per key
+ meta = df_idx[key_cols]
+ if not meta.index.is_unique:
+ meta = meta.groupby(level=key_cols, dropna=False).first()
+
+ # metric series for this file, aggregated to one row per key
+ file_label = "/".join(file.split("/")[:-1]) or os.path.basename(file)
+ s = df_idx[data_column]
+ if not s.index.is_unique:
+ s = s.groupby(level=key_cols, dropna=False).mean()
+ s.name = file_label # column label like original
+
+ # add meta once (from first file) so keys are the leftmost columns
+ if not meta_added:
+ frames.append(meta)
+ meta_added = True
+
+ # (NEW) debug: aligned test-name column per file
+ if debug and name_column in df_idx.columns:
+ name_s = df_idx[name_column]
+ if not name_s.index.is_unique:
+ name_s = name_s.groupby(level=key_cols, dropna=False).first()
+ name_s.name = f"{file_label}_name"
+ frames.append(name_s)
+
+ frames.append(s)
+ raw_data_cols.append(file_label)
+ compare_frames.append(s)
+
+ # Generalize ratio: for any file N>=2, add ratio (fileN / file1)
if len(compare_frames) >= 2:
- # Compare numbers among two files
- ratio_df = compare_frames[1] / compare_frames[0]
- frames.append(ratio_df)
- compare_frames.pop(1)
+ base = compare_frames[0]
+ current = compare_frames[-1]
+ ratio = current / base
+ ratio = ratio.mask(base == 0) # avoid inf when baseline is 0
+ ratio.name = f"Ratio 1 vs {len(compare_frames)}"
+ frames.append(ratio)
+ # 4) concat on columns with aligned MultiIndex;
+ # then reset_index to return keys as columns
concat_df = pd.concat(frames, axis=1)
+ concat_df = concat_df.reset_index(drop=True).reset_index()
+ if "index" in concat_df.columns:
+ concat_df = concat_df.drop(columns=["index"])
+
+ # Ensure key/info columns appear first (in your info_cols order)
+ front = [c for c in info_cols if c in concat_df.columns]
+ rest = [c for c in concat_df.columns if c not in front]
+ concat_df = concat_df[front + rest]
+
print(raw_data_cols)
return concat_df, raw_data_cols
@@ -67,6 +152,15 @@ def split_json_by_tp_pp(
df = pd.DataFrame(data)
+ # Keep only "serving" tests
+ name_col = next(
+ (c for c in ["Test name", "test_name", "Test Name"] if c in df.columns), None
+ )
+ if name_col:
+ df = df[
+ df[name_col].astype(str).str.contains(r"serving", case=False, na=False)
+ ].copy()
+
# Handle alias column names
rename_map = {
"tp_size": "TP Size",
@@ -181,7 +275,6 @@ if __name__ == "__main__":
f"Expected subset: {filtered_info_cols}, "
f"but DataFrame has: {list(output_df.columns)}"
)
-
output_df_sorted = output_df.sort_values(by=existing_group_cols)
output_groups = output_df_sorted.groupby(existing_group_cols, dropna=False)
for name, group in output_groups:
@@ -189,8 +282,7 @@ if __name__ == "__main__":
text_file.write(html_msgs_for_data_cols[i])
text_file.write(html)
- if plot is True:
- import pandas as pd
+ if plot and plotly_found:
import plotly.express as px
df = group[raw_data_cols]
diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml
index 85d3e56387421..f96c38bf57db7 100644
--- a/.buildkite/release-pipeline.yaml
+++ b/.buildkite/release-pipeline.yaml
@@ -27,7 +27,12 @@ steps:
env:
DOCKER_BUILDKIT: "1"
+ - block: "Build CUDA 12.6 wheel"
+ key: block-build-cu126-wheel
+ depends_on: ~
+
- label: "Build wheel - CUDA 12.6"
+ depends_on: block-build-cu126-wheel
id: build-wheel-cuda-12-6
agents:
queue: cpu_queue_postmerge
@@ -68,7 +73,7 @@ steps:
queue: cpu_queue_postmerge
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
+ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
- label: "Annotate release workflow"
diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh
index 57a7bc4e5f5df..9dec9f8e9eb32 100644
--- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh
@@ -46,6 +46,11 @@ function cpu_tests() {
set -e
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
+ # Run kernel tests
+ docker exec cpu-test-"$NUMA_NODE" bash -c "
+ set -e
+ pytest -v -s tests/kernels/test_onednn.py"
+
# Run basic model test
docker exec cpu-test-"$NUMA_NODE" bash -c "
set -e
@@ -99,4 +104,4 @@ function cpu_tests() {
# All of CPU tests are expected to be finished less than 40 mins.
export -f cpu_tests
-timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
+timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh
index deb61a9bafab6..445cd2735c190 100644
--- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh
@@ -23,9 +23,13 @@ docker run \
--device /dev/dri \
-v /dev/dri/by-path:/dev/dri/by-path \
--entrypoint="" \
+ -e "HF_TOKEN=${HF_TOKEN}" \
+ -e "ZE_AFFINITY_MASK=${ZE_AFFINITY_MASK}" \
--name "${container_name}" \
"${image_name}" \
- sh -c '
+ bash -c '
+ set -e
+ echo $ZE_AFFINITY_MASK
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
@@ -35,8 +39,8 @@ docker run \
pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py
pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
pytest -v -s v1/structured_output
- pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py
- pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py
+ pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py
+ pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py
pytest -v -s v1/test_serial_utils.py
pytest -v -s v1/test_utils.py
pytest -v -s v1/test_metrics_reader.py
diff --git a/.buildkite/scripts/tpu/cleanup_docker.sh b/.buildkite/scripts/tpu/cleanup_docker.sh
index 209d9c4341cdd..740d81fb39bb0 100755
--- a/.buildkite/scripts/tpu/cleanup_docker.sh
+++ b/.buildkite/scripts/tpu/cleanup_docker.sh
@@ -17,7 +17,7 @@ if [ "$disk_usage" -gt "$threshold" ]; then
# Remove dangling images (those that are not tagged and not used by any container)
docker image prune -f
# Remove unused volumes / force the system prune for old images as well.
- docker volume prune -f && docker system prune --force --filter "until=72h" --all
+ docker volume prune -f && docker system prune --force --filter "until=24h" --all
echo "Docker images and volumes cleanup completed."
else
echo "Disk usage is below $threshold%. No cleanup needed."
diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh
index 037897e53dbef..745f285c008ad 100644
--- a/.buildkite/scripts/upload-wheels.sh
+++ b/.buildkite/scripts/upload-wheels.sh
@@ -14,8 +14,19 @@ fi
# Get the single wheel file
wheel="${wheel_files[0]}"
-# Rename 'linux' to 'manylinux1' in the wheel filename
-new_wheel="${wheel/linux/manylinux1}"
+# Detect architecture and rename 'linux' to appropriate manylinux version
+arch=$(uname -m)
+if [[ $arch == "x86_64" ]]; then
+ manylinux_version="manylinux1"
+elif [[ $arch == "aarch64" ]]; then
+ manylinux_version="manylinux2014"
+else
+ echo "Warning: Unknown architecture $arch, using manylinux1 as default"
+ manylinux_version="manylinux1"
+fi
+
+# Rename 'linux' to the appropriate manylinux version in the wheel filename
+new_wheel="${wheel/linux/$manylinux_version}"
mv -- "$wheel" "$new_wheel"
wheel="$new_wheel"
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 4fc8857854927..df2735fefeedb 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -88,15 +88,6 @@ steps:
- pytest -v -s basic_correctness/test_cpu_offload.py
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
-- label: Chunked Prefill Test
- mirror_hardwares: [amdexperimental]
- source_file_dependencies:
- - vllm/
- - tests/basic_correctness/test_chunked_prefill
- commands:
- - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
- - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
-
- label: Core Test # 10min
mirror_hardwares: [amdexperimental]
fast_check: true
@@ -135,7 +126,8 @@ steps:
- tests/entrypoints/test_chat_utils
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/
+ - PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension
+ - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py
- pytest -v -s entrypoints/test_chat_utils.py
- label: Distributed Tests (4 GPUs) # 10min
@@ -252,6 +244,7 @@ steps:
- pytest -v -s v1/core
- pytest -v -s v1/engine
- pytest -v -s v1/entrypoints
+ - pytest -v -s v1/executor
- pytest -v -s v1/sample
- pytest -v -s v1/logits_processors
- pytest -v -s v1/worker
@@ -295,15 +288,6 @@ steps:
- python3 offline_inference/basic/score.py
- VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2
-- label: Prefix Caching Test # 9min
- mirror_hardwares: [amdexperimental]
- source_file_dependencies:
- - vllm/
- - tests/prefix_caching
- commands:
- - pytest -v -s prefix_caching
-
-
- label: Platform Tests (CUDA)
mirror_hardwares: [amdexperimental]
source_file_dependencies:
@@ -345,6 +329,7 @@ steps:
- pytest -v -s compile/test_sequence_parallelism.py
- pytest -v -s compile/test_async_tp.py
- pytest -v -s compile/test_fusion_all_reduce.py
+ - pytest -v -s compile/test_decorator.py
- label: PyTorch Fullgraph Smoke Test # 9min
mirror_hardwares: [amdexperimental]
@@ -358,6 +343,7 @@ steps:
- pytest -v -s compile/piecewise/test_simple.py
- pytest -v -s compile/piecewise/test_toy_llama.py
- pytest -v -s compile/piecewise/test_full_cudagraph.py
+ - pytest -v -s compile/piecewise/test_multiple_graphs.py
- label: PyTorch Fullgraph Test # 18min
mirror_hardwares: [amdexperimental]
@@ -468,13 +454,11 @@ steps:
- label: LM Eval Small Models # 53min
mirror_hardwares: [amdexperimental]
- working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
commands:
- - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
+ - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
- label: OpenAI API correctness
mirror_hardwares: [amdexperimental]
@@ -562,6 +546,15 @@ steps:
commands:
- pytest -v -s models/language/pooling -m 'not core_model'
+- label: Multi-Modal Processor Test
+ source_file_dependencies:
+ - vllm/
+ - tests/models/multimodal
+ commands:
+ - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
+ - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
+ - pytest -v -s models/multimodal/processing/test_tensor_schema.py
+
- label: Multi-Modal Models Test (Standard)
mirror_hardwares: [amdexperimental]
torch_nightly: true
@@ -571,9 +564,7 @@ steps:
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pip freeze | grep -E 'torch'
- - pytest -v -s models/multimodal/processing
- - pytest -v -s --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/test_tensor_schema.py models/multimodal -m core_model
- - pytest -v -s models/multimodal/test_tensor_schema.py -m core_model # Needs mp_method="spawn"
+ - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
- cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
- label: Multi-Modal Models Test (Extended) 1
@@ -584,7 +575,7 @@ steps:
- tests/models/multimodal
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- - pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model'
+ - pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing
- label: Multi-Modal Models Test (Extended) 2
mirror_hardwares: [amdexperimental]
@@ -647,8 +638,10 @@ steps:
- vllm/model_executor/layers/fused_moe/cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+ - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/fusion.py
+ - vllm/compilation/fusion_attn.py
commands:
- nvidia-smi
- python3 examples/offline_inference/basic/chat.py
@@ -663,8 +656,11 @@ steps:
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
+ - pytest -v -s tests/kernels/moe/test_mxfp4_moe.py
# Fusion
- pytest -v -s tests/compile/test_fusion_all_reduce.py
+ - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
+ - pytest -v -s tests/kernels/moe/test_flashinfer.py
##### 1 GPU test #####
##### multi gpus test #####
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index b0dd5e99d4c72..ce9590f02ce71 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -10,6 +10,7 @@
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256
+/vllm/model_executor/layers/mamba @tdoublep
/vllm/multimodal @DarkLight1337 @ywang96
/vllm/vllm_flash_attn @LucasWilkinson
/vllm/lora @jeejeelee
@@ -25,11 +26,11 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
# vLLM V1
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
/vllm/v1/structured_output @mgoin @russellb @aarnphm
+/vllm/v1/attention/backends/triton_attn.py @tdoublep
# Test ownership
/.buildkite/lm-eval-harness @mgoin @simon-mo
/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo
-/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac
/tests/distributed/test_multi_node_assignment.py @youkaichao
/tests/distributed/test_pipeline_parallel.py @youkaichao
/tests/distributed/test_same_node.py @youkaichao
@@ -44,6 +45,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/v1/structured_output @mgoin @russellb @aarnphm
/tests/weight_loading @mgoin @youkaichao @yewentao256
/tests/lora @jeejeelee
+/tests/models/language/generation/test_hybrid.py @tdoublep
# Docs
/docs @hmellor
@@ -72,3 +74,9 @@ mkdocs.yaml @hmellor
/vllm/model_executor/models/pixtral*.py @patrickvonplaten
/vllm/transformers_utils/configs/mistral.py @patrickvonplaten
/vllm/transformers_utils/tokenizers/mistral.py @patrickvonplaten
+
+# Kernels
+/vllm/attention/ops/chunked_prefill_paged_decode.py @tdoublep
+/vllm/attention/ops/triton_unified_attention.py @tdoublep
+
+
diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml
deleted file mode 100644
index 2b1086b7faf43..0000000000000
--- a/.github/workflows/lint-and-deploy.yaml
+++ /dev/null
@@ -1,89 +0,0 @@
-name: Lint and Deploy Charts
-
-on: pull_request
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
- cancel-in-progress: true
-
-permissions:
- contents: read
-
-jobs:
- lint-and-deploy:
- runs-on: ubuntu-latest
- steps:
- - name: Checkout
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- with:
- fetch-depth: 0
-
- - name: Set up Helm
- uses: azure/setup-helm@b9e51907a09c216f16ebe8536097933489208112 # v4.3.0
- with:
- version: v3.14.4
-
- #Python is required because ct lint runs Yamale and yamllint which require Python.
- - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
- with:
- python-version: '3.13'
-
- - name: Set up chart-testing
- uses: helm/chart-testing-action@0d28d3144d3a25ea2cc349d6e59901c4ff469b3b # v2.7.0
- with:
- version: v3.10.1
-
- - name: Run chart-testing (lint)
- run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/online_serving/chart-helm --charts examples/online_serving/chart-helm
-
- - name: Setup minio
- run: |
- docker network create vllm-net
- docker run -d -p 9000:9000 --name minio --net vllm-net \
- -e "MINIO_ACCESS_KEY=minioadmin" \
- -e "MINIO_SECRET_KEY=minioadmin" \
- -v /tmp/data:/data \
- -v /tmp/config:/root/.minio \
- minio/minio server /data
- export AWS_ACCESS_KEY_ID=minioadmin
- export AWS_SECRET_ACCESS_KEY=minioadmin
- export AWS_EC2_METADATA_DISABLED=true
- mkdir opt-125m
- cd opt-125m && curl -O -Ls "https://huggingface.co/facebook/opt-125m/resolve/main/{pytorch_model.bin,config.json,generation_config.json,merges.txt,special_tokens_map.json,tokenizer_config.json,vocab.json}" && cd ..
- aws --endpoint-url http://127.0.0.1:9000/ s3 mb s3://testbucket
- aws --endpoint-url http://127.0.0.1:9000/ s3 cp opt-125m/ s3://testbucket/opt-125m --recursive
-
- - name: Create kind cluster
- uses: helm/kind-action@a1b0e391336a6ee6713a0583f8c6240d70863de3 # v1.12.0
-
- - name: Build the Docker image vllm cpu
- run: docker buildx build -f docker/Dockerfile.cpu -t vllm-cpu-env .
-
- - name: Configuration of docker images, network and namespace for the kind cluster
- run: |
- docker pull amazon/aws-cli:2.6.4
- kind load docker-image amazon/aws-cli:2.6.4 --name chart-testing
- kind load docker-image vllm-cpu-env:latest --name chart-testing
- docker network connect vllm-net "$(docker ps -aqf "name=chart-testing-control-plane")"
- kubectl create ns ns-vllm
-
- - name: Run chart-testing (install)
- run: |
- export AWS_ACCESS_KEY_ID=minioadmin
- export AWS_SECRET_ACCESS_KEY=minioadmin
- sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
- helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
-
- - name: curl test
- run: |
- kubectl -n ns-vllm port-forward service/test-vllm-service 8001:80 &
- sleep 10
- CODE="$(curl -v -f --location http://localhost:8001/v1/completions \
- --header "Content-Type: application/json" \
- --data '{
- "model": "opt-125m",
- "prompt": "San Francisco is a",
- "max_tokens": 7,
- "temperature": 0
- }'):$CODE"
- echo "$CODE"
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
deleted file mode 100644
index bfd02879965ee..0000000000000
--- a/.github/workflows/publish.yml
+++ /dev/null
@@ -1,111 +0,0 @@
-# This workflow will upload a Python Package to Release asset
-# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions
-
-name: Create Release
-
-on:
- push:
- tags:
- - v*
-
-# Needed to create release and upload assets
-permissions:
- contents: write
-
-jobs:
- release:
- # Retrieve tag and create release
- name: Create Release
- runs-on: ubuntu-latest
- outputs:
- upload_url: ${{ steps.create_release.outputs.upload_url }}
- steps:
- - name: Checkout
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-
- - name: Extract branch info
- shell: bash
- run: |
- echo "release_tag=${GITHUB_REF#refs/*/}" >> "$GITHUB_ENV"
-
- - name: Create Release
- id: create_release
- uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
- env:
- RELEASE_TAG: ${{ env.release_tag }}
- with:
- github-token: "${{ secrets.GITHUB_TOKEN }}"
- script: |
- const script = require('.github/workflows/scripts/create_release.js')
- await script(github, context, core)
-
- # NOTE(simon): No longer build wheel using GitHub Actions. See buildkite's release workflow.
- # wheel:
- # name: Build Wheel
- # runs-on: ${{ matrix.os }}
- # needs: release
-
- # strategy:
- # fail-fast: false
- # matrix:
- # os: ['ubuntu-20.04']
- # python-version: ['3.9', '3.10', '3.11', '3.12']
- # pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements/cuda.txt.
- # cuda-version: ['11.8', '12.1']
-
- # steps:
- # - name: Checkout
- # uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-
- # - name: Setup ccache
- # uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14
- # with:
- # create-symlink: true
- # key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }}
-
- # - name: Set up Linux Env
- # if: ${{ runner.os == 'Linux' }}
- # run: |
- # bash -x .github/workflows/scripts/env.sh
-
- # - name: Set up Python
- # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
- # with:
- # python-version: ${{ matrix.python-version }}
-
- # - name: Install CUDA ${{ matrix.cuda-version }}
- # run: |
- # bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}
-
- # - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }}
- # run: |
- # bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }}
-
- # - name: Build wheel
- # shell: bash
- # env:
- # CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size
- # run: |
- # bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
- # wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename)
- # asset_name=${wheel_name//"linux"/"manylinux1"}
- # echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV"
- # echo "asset_name=${asset_name}" >> "$GITHUB_ENV"
-
- # - name: Upload Release Asset
- # uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2
- # env:
- # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- # with:
- # upload_url: ${{ needs.release.outputs.upload_url }}
- # asset_path: ./dist/${{ env.wheel_name }}
- # asset_name: ${{ env.asset_name }}
- # asset_content_type: application/*
-
- # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
- # - name: Publish package
- # uses: pypa/gh-action-pypi-publish@release/v1.8
- # with:
- # repository-url: https://test.pypi.org/legacy/
- # password: ${{ secrets.PYPI_API_TOKEN }}
- # skip-existing: true
diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml
index 16ae1aadb96be..1ee605dc7bb0d 100644
--- a/.github/workflows/reminder_comment.yml
+++ b/.github/workflows/reminder_comment.yml
@@ -12,16 +12,43 @@ jobs:
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
with:
script: |
- github.rest.issues.createComment({
- owner: context.repo.owner,
- repo: context.repo.repo,
- issue_number: context.issue.number,
- body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' +
- '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' +
- 'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org.\n\n' +
- 'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' +
- 'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' +
- '🚀'
- })
+ try {
+ // Get the PR author
+ const prAuthor = context.payload.pull_request.user.login;
+
+ // Check if this is the author's first PR in this repository
+ // Use GitHub's search API to find all PRs by this author
+ const { data: searchResults } = await github.rest.search.issuesAndPullRequests({
+ q: `repo:${context.repo.owner}/${context.repo.repo} type:pr author:${prAuthor}`,
+ per_page: 100
+ });
+
+ const authorPRCount = searchResults.total_count;
+
+ console.log(`Found ${authorPRCount} PRs by ${prAuthor}`);
+
+ // Only post comment if this is the first PR (only one PR by this author)
+ if (authorPRCount === 1) {
+ console.log(`Posting welcome comment for first-time contributor: ${prAuthor}`);
+ await github.rest.issues.createComment({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ issue_number: context.issue.number,
+ body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' +
+ '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' +
+ 'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. \n\n' +
+ 'You can ask your reviewers to trigger select CI tests on top of `fastcheck` CI. \n\n' +
+ 'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' +
+ 'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' +
+ 'If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.\n\n' +
+ '🚀'
+ });
+ } else {
+ console.log(`Skipping comment for ${prAuthor} - not their first PR (${authorPRCount} PRs found)`);
+ }
+ } catch (error) {
+ console.error('Error checking PR history or posting comment:', error);
+ // Don't fail the workflow, just log the error
+ }
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 34386d670ac76..a1deefb07f09c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -30,7 +30,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
# Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py.
#
-set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
+set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12" "3.13")
# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")
@@ -357,9 +357,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
set(MARLIN_SRCS
- "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
- "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
diff --git a/benchmarks/README.md b/benchmarks/README.md
index 1d715a193ea14..176b40212978f 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -32,6 +32,14 @@ become available.
Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:
wget http://images.cocodataset.org/zips/train2017.zip
+
+
+ | ShareGPT4Video (Video) |
+ ✅ |
+ ✅ |
+
+ git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video
+ |
| BurstGPT |
@@ -194,6 +202,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct
```bash
vllm bench serve \
--backend openai-chat \
+ --endpoint-type openai-chat \
--model Qwen/Qwen2-VL-7B-Instruct \
--endpoint /v1/chat/completions \
--dataset-name hf \
@@ -230,6 +239,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct
```bash
vllm bench serve \
--backend openai-chat \
+ --endpoint-type openai-chat \
--model Qwen/Qwen2-VL-7B-Instruct \
--endpoint /v1/chat/completions \
--dataset-name hf \
@@ -244,6 +254,7 @@ vllm bench serve \
```bash
vllm bench serve \
--backend openai-chat \
+ --endpoint-type openai-chat \
--model Qwen/Qwen2-VL-7B-Instruct \
--endpoint /v1/chat/completions \
--dataset-name hf \
@@ -609,7 +620,7 @@ vllm bench serve \
--prefix-repetition-prefix-len 512 \
--prefix-repetition-suffix-len 128 \
--prefix-repetition-num-prefixes 5 \
- --prefix-repetition-output-len 128
+ --prefix-repetition-output-len 128
```
@@ -684,4 +695,31 @@ python benchmarks/benchmark_serving.py \
--endpoint /v1/chat/completion
```
+### Videos (ShareGPT4Video)
+
+Start vLLM:
+
+```bash
+python -m vllm.entrypoints.openai.api_server \
+ --model Qwen/Qwen2.5-VL-7B-Instruct \
+ --dtype bfloat16 \
+ --limit-mm-per-prompt '{"video": 1}' \
+ --allowed-local-media-path /path/to/sharegpt4video/videos
+```
+
+Send requests with videos:
+
+```bash
+python benchmarks/benchmark_serving.py \
+ --backend openai-chat \
+ --model Qwen/Qwen2.5-VL-7B-Instruct \
+ --dataset-name sharegpt \
+ --dataset-path /path/to/ShareGPT4Video/llava_v1_5_mix665k_with_video_chatgpt72k_share4video28k.json \
+ --num-prompts 100 \
+ --save-result \
+ --result-dir ~/vllm_benchmark_results \
+ --save-detailed \
+ --endpoint /v1/chat/completion
+```
+
diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py
index 1559ca2d92841..ba7c733be0b25 100644
--- a/benchmarks/backend_request_func.py
+++ b/benchmarks/backend_request_func.py
@@ -34,6 +34,7 @@ class RequestFuncInput:
multi_modal_content: Optional[dict | list[dict]] = None
ignore_eos: bool = False
language: Optional[str] = None
+ request_id: Optional[str] = None
@dataclass
@@ -71,6 +72,9 @@ async def async_request_tgi(
"inputs": request_func_input.prompt,
"parameters": params,
}
+ headers = None
+ if request_func_input.request_id:
+ headers = {"x-request-id": request_func_input.request_id}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
if request_func_input.ignore_eos:
@@ -82,7 +86,9 @@ async def async_request_tgi(
st = time.perf_counter()
most_recent_timestamp = st
try:
- async with session.post(url=api_url, json=payload) as response:
+ async with session.post(
+ url=api_url, json=payload, headers=headers
+ ) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
@@ -145,6 +151,9 @@ async def async_request_trt_llm(
}
if request_func_input.ignore_eos:
payload["min_length"] = request_func_input.output_len
+ headers = None
+ if request_func_input.request_id:
+ headers = {"x-request-id": request_func_input.request_id}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
@@ -152,7 +161,9 @@ async def async_request_trt_llm(
st = time.perf_counter()
most_recent_timestamp = st
try:
- async with session.post(url=api_url, json=payload) as response:
+ async with session.post(
+ url=api_url, json=payload, headers=headers
+ ) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
@@ -211,6 +222,8 @@ async def async_request_deepspeed_mii(
"top_p": 1.0,
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+ if request_func_input.request_id:
+ headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
@@ -283,6 +296,8 @@ async def async_request_openai_completions(
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+ if request_func_input.request_id:
+ headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
@@ -395,6 +410,8 @@ async def async_request_openai_chat_completions(
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
+ if request_func_input.request_id:
+ headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
@@ -491,6 +508,8 @@ async def async_request_openai_audio(
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
+ if request_func_input.request_id:
+ headers["x-request-id"] = request_func_input.request_id
# Send audio file
def to_bytes(y, sr):
diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py
index 572292a5aca46..2ea4f9ccaff2b 100644
--- a/benchmarks/benchmark_dataset.py
+++ b/benchmarks/benchmark_dataset.py
@@ -19,6 +19,7 @@ import logging
import random
from abc import ABC, abstractmethod
from collections.abc import Mapping
+from copy import deepcopy
from dataclasses import dataclass
from functools import cache
from io import BytesIO
@@ -54,6 +55,7 @@ class SampleRequest:
expected_output_len: int
multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None
lora_request: Optional[LoRARequest] = None
+ request_id: Optional[str] = None
# -----------------------------------------------------------------------------
@@ -155,7 +157,10 @@ class BenchmarkDataset(ABC):
@abstractmethod
def sample(
- self, tokenizer: PreTrainedTokenizerBase, num_requests: int
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ request_id_prefix: str = "",
) -> list[SampleRequest]:
"""
Abstract method to generate sample requests from the dataset.
@@ -167,6 +172,7 @@ class BenchmarkDataset(ABC):
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
for processing the dataset's text.
num_requests (int): The number of sample requests to generate.
+ request_id_prefix (str): The prefix of request_id.
Returns:
list[SampleRequest]: A list of sample requests generated from the
@@ -175,7 +181,10 @@ class BenchmarkDataset(ABC):
raise NotImplementedError("sample must be implemented in subclasses.")
def maybe_oversample_requests(
- self, requests: list[SampleRequest], num_requests: int
+ self,
+ requests: list[SampleRequest],
+ num_requests: int,
+ request_id_prefix: str = "",
) -> None:
"""
Oversamples the list of requests if its size is less than the desired
@@ -183,11 +192,18 @@ class BenchmarkDataset(ABC):
Args:
requests (List[SampleRequest]): The current list of sampled
- requests. num_requests (int): The target number of requests.
+ requests.
+ num_requests (int): The target number of requests.
+ request_id_prefix (str): The prefix of the request ids.
"""
if len(requests) < num_requests:
random.seed(self.random_seed)
- additional = random.choices(requests, k=num_requests - len(requests))
+ additional = deepcopy(
+ random.choices(requests, k=num_requests - len(requests))
+ )
+ for i in range(len(additional)):
+ req = additional[i]
+ req.request_id = request_id_prefix + str(len(requests) + i)
requests.extend(additional)
logger.info("Oversampled requests to reach %d total samples.", num_requests)
@@ -277,6 +293,41 @@ def process_image(image: Any) -> Mapping[str, Any]:
)
+def process_video(video: Any) -> Mapping[str, Any]:
+ """
+ Process a single video input and return a multimedia content dictionary.
+
+ Supports the following input types:
+
+ 1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key
+ containing raw video data.
+
+ 2. String input: - Treats the string as a URL or local file path. -
+ Prepends "file://" if the string doesn't start with "http://" or
+ "file://". - Returns a dictionary with the video URL.
+
+ Raises:
+ ValueError: If the input is not a supported type.
+ """
+ if isinstance(video, dict) and "bytes" in video:
+ video_bytes = video["bytes"]
+ video_base64 = base64.b64encode(video_bytes).decode("utf-8")
+ return {
+ "type": "video_url",
+ "video_url": {"url": f"data:video/mp4;base64,{video_base64}"},
+ }
+
+ if isinstance(video, str):
+ video_url = (
+ video if video.startswith(("http://", "file://")) else f"file://{video}"
+ )
+ return {"type": "video_url", "video_url": {"url": video_url}}
+
+ raise ValueError(
+ f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501
+ )
+
+
# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# -----------------------------------------------------------------------------
@@ -303,6 +354,7 @@ class RandomDataset(BenchmarkDataset):
range_ratio: float = DEFAULT_RANGE_RATIO,
input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN,
+ request_id_prefix: str = "",
**kwargs,
) -> list[SampleRequest]:
# Enforce range_ratio < 1
@@ -363,8 +415,10 @@ class RandomDataset(BenchmarkDataset):
prompt=prompt,
prompt_len=total_input_len,
expected_output_len=int(output_lens[i]),
+ request_id=request_id_prefix + str(i),
)
)
+
return requests
@@ -406,9 +460,11 @@ class ShareGPTDataset(BenchmarkDataset):
max_loras: Optional[int] = None,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
samples: list = []
+ ind = 0
for entry in self.data:
if len(samples) >= num_requests:
break
@@ -430,9 +486,10 @@ class ShareGPTDataset(BenchmarkDataset):
skip_min_output_len_check=output_len is not None,
):
continue
- # TODO: Also support ShareGPT4Video.
if image_path := entry.get("image"):
mm_content = process_image(image_path)
+ elif video_path := entry.get("video"):
+ mm_content = process_video(video_path)
else:
mm_content = None
if enable_multimodal_chat:
@@ -444,9 +501,11 @@ class ShareGPTDataset(BenchmarkDataset):
expected_output_len=new_output_len,
lora_request=lora_request,
multi_modal_data=mm_content,
+ request_id=request_id_prefix + str(ind),
)
)
- self.maybe_oversample_requests(samples, num_requests)
+ ind += 1
+ self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples
@@ -512,10 +571,11 @@ class CustomDataset(BenchmarkDataset):
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
skip_chat_template: bool = False,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
sampled_requests = []
- for item in self.data:
+ for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
prompt = item["prompt"]
@@ -534,9 +594,12 @@ class CustomDataset(BenchmarkDataset):
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
+ request_id=request_id_prefix + str(i),
)
)
- self.maybe_oversample_requests(sampled_requests, num_requests)
+ self.maybe_oversample_requests(
+ sampled_requests, num_requests, request_id_prefix
+ )
return sampled_requests
@@ -578,6 +641,7 @@ class SonnetDataset(BenchmarkDataset):
input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN,
return_prompt_formatted: bool = False,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
# Calculate average token length for a poem line.
@@ -603,6 +667,7 @@ class SonnetDataset(BenchmarkDataset):
prefix_lines = self.data[:num_prefix_lines]
samples = []
+ ind = 0
while len(samples) < num_requests:
extra_lines = random.choices(
self.data, k=num_input_lines - num_prefix_lines
@@ -613,14 +678,17 @@ class SonnetDataset(BenchmarkDataset):
msg, add_generation_prompt=True, tokenize=False
)
prompt_len = len(tokenizer(prompt_formatted).input_ids)
+
if prompt_len <= input_len:
samples.append(
SampleRequest(
prompt=prompt_formatted if return_prompt_formatted else prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
+ request_id=request_id_prefix + str(ind),
)
)
+ ind += 1
return samples
@@ -672,6 +740,7 @@ class BurstGPTDataset(BenchmarkDataset):
num_requests: int,
max_loras: Optional[int] = None,
lora_path: Optional[str] = None,
+ request_id_prefix: str = "",
**kwargs,
) -> list[SampleRequest]:
samples = []
@@ -693,6 +762,7 @@ class BurstGPTDataset(BenchmarkDataset):
prompt_len=input_len,
expected_output_len=output_len,
lora_request=lora_req,
+ request_id=request_id_prefix + str(i),
)
)
return samples
@@ -752,12 +822,14 @@ class ConversationDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
# Filter examples with at least 2 conversations
filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
sampled_requests = []
dynamic_output = output_len is None
+ ind = 0
for item in filtered_data:
if len(sampled_requests) >= num_requests:
@@ -785,9 +857,13 @@ class ConversationDataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
+ request_id=request_id_prefix + str(ind),
)
)
- self.maybe_oversample_requests(sampled_requests, num_requests)
+ ind += 1
+ self.maybe_oversample_requests(
+ sampled_requests, num_requests, request_id_prefix
+ )
return sampled_requests
@@ -814,11 +890,12 @@ class VisionArenaDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = []
- for item in self.data:
+ for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
@@ -838,9 +915,12 @@ class VisionArenaDataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
+ request_id=request_id_prefix + str(i),
)
)
- self.maybe_oversample_requests(sampled_requests, num_requests)
+ self.maybe_oversample_requests(
+ sampled_requests, num_requests, request_id_prefix
+ )
return sampled_requests
@@ -870,15 +950,18 @@ class InstructCoderDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = []
- for item in self.data:
+ for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
- prompt = f"{item['input']}\n\n{item['instruction']} Just output \
- the code, do not include any explanation."
+ prompt = (
+ f"{item['input']}\n\n{item['instruction']} Just output "
+ "the code, do not include any explanation."
+ )
# apply template
prompt = tokenizer.apply_chat_template(
@@ -892,9 +975,12 @@ class InstructCoderDataset(HuggingFaceDataset):
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
+ request_id=request_id_prefix + str(i),
)
)
- self.maybe_oversample_requests(sampled_requests, num_requests)
+ self.maybe_oversample_requests(
+ sampled_requests, num_requests, request_id_prefix
+ )
return sampled_requests
@@ -924,12 +1010,13 @@ class MTBenchDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = []
- for item in self.data:
+ for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
prompt = item["turns"][0]
@@ -947,9 +1034,12 @@ class MTBenchDataset(HuggingFaceDataset):
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
+ request_id=request_id_prefix + str(i),
)
)
- self.maybe_oversample_requests(sampled_requests, num_requests)
+ self.maybe_oversample_requests(
+ sampled_requests, num_requests, request_id_prefix
+ )
return sampled_requests
@@ -974,10 +1064,12 @@ class AIMODataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
sampled_requests = []
dynamic_output = output_len is None
+ ind = 0
for item in self.data:
if len(sampled_requests) >= num_requests:
@@ -1000,9 +1092,13 @@ class AIMODataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=None,
+ request_id=request_id_prefix + str(ind),
)
)
- self.maybe_oversample_requests(sampled_requests, num_requests)
+ ind += 1
+ self.maybe_oversample_requests(
+ sampled_requests, num_requests, request_id_prefix
+ )
return sampled_requests
@@ -1072,12 +1168,18 @@ class NextEditPredictionDataset(HuggingFaceDataset):
"zed-industries/zeta": _format_zeta_prompt,
}
- def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs):
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ request_id_prefix: str = "",
+ **kwargs,
+ ):
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path)
if formatting_prompt_func is None:
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
samples = []
- for sample in self.data:
+ for i, sample in enumerate(self.data):
sample = formatting_prompt_func(sample)
samples.append(
SampleRequest(
@@ -1086,11 +1188,12 @@ class NextEditPredictionDataset(HuggingFaceDataset):
expected_output_len=len(
tokenizer(sample["expected_output"]).input_ids
),
+ request_id=request_id_prefix + str(i),
)
)
if len(samples) >= num_requests:
break
- self.maybe_oversample_requests(samples, num_requests)
+ self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples
@@ -1139,6 +1242,7 @@ class ASRDataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
import librosa
@@ -1148,6 +1252,7 @@ class ASRDataset(HuggingFaceDataset):
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = []
skipped = 0
+ ind = 0
for item in self.data:
if len(sampled_requests) >= num_requests:
break
@@ -1166,8 +1271,10 @@ class ASRDataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
+ request_id=request_id_prefix + str(ind),
)
)
+ ind += 1
if skipped:
logger.warning(
"%d samples discarded from dataset due to"
@@ -1175,5 +1282,7 @@ class ASRDataset(HuggingFaceDataset):
" what Whisper supports.",
skipped,
)
- self.maybe_oversample_requests(sampled_requests, num_requests)
+ self.maybe_oversample_requests(
+ sampled_requests, num_requests, request_id_prefix
+ )
return sampled_requests
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index ae38caf7290b1..02f5f585c0c16 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -375,11 +375,12 @@ async def benchmark(
rps_change_events.append({"rps": rps_val, "timestamp": timestamp})
last_int_rps = current_int_rps
- prompt, prompt_len, output_len, mm_content = (
+ prompt, prompt_len, output_len, mm_content, request_id = (
request.prompt,
request.prompt_len,
request.expected_output_len,
request.multi_modal_data,
+ request.request_id,
)
req_model_id, req_model_name = model_id, model_name
if lora_modules:
@@ -397,6 +398,7 @@ async def benchmark(
multi_modal_content=mm_content,
ignore_eos=ignore_eos,
extra_body=extra_body,
+ request_id=request_id,
)
task = limited_request_func(request_func_input=request_func_input, pbar=pbar)
tasks.append(asyncio.create_task(task))
@@ -665,6 +667,7 @@ def main(args: argparse.Namespace):
tokenizer=tokenizer,
output_len=args.custom_output_len,
skip_chat_template=args.custom_skip_chat_template,
+ request_id_prefix=args.request_id_prefix,
)
elif args.dataset_name == "sonnet":
@@ -678,6 +681,7 @@ def main(args: argparse.Namespace):
prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer,
return_prompt_formatted=False,
+ request_id_prefix=args.request_id_prefix,
)
else:
assert tokenizer.chat_template or tokenizer.default_chat_template, (
@@ -690,6 +694,7 @@ def main(args: argparse.Namespace):
prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer,
return_prompt_formatted=True,
+ request_id_prefix=args.request_id_prefix,
)
elif args.dataset_name == "hf":
@@ -751,6 +756,7 @@ def main(args: argparse.Namespace):
num_requests=args.num_prompts,
tokenizer=tokenizer,
output_len=args.hf_output_len,
+ request_id_prefix=args.request_id_prefix,
)
else:
@@ -762,10 +768,15 @@ def main(args: argparse.Namespace):
tokenizer=tokenizer,
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
+ request_id_prefix=args.request_id_prefix,
),
"burstgpt": lambda: BurstGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
- ).sample(tokenizer=tokenizer, num_requests=args.num_prompts),
+ ).sample(
+ tokenizer=tokenizer,
+ num_requests=args.num_prompts,
+ request_id_prefix=args.request_id_prefix,
+ ),
"random": lambda: RandomDataset(dataset_path=args.dataset_path).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
@@ -773,6 +784,7 @@ def main(args: argparse.Namespace):
input_len=args.random_input_len,
output_len=args.random_output_len,
range_ratio=args.random_range_ratio,
+ request_id_prefix=args.request_id_prefix,
),
}
@@ -1118,6 +1130,13 @@ def create_argument_parser():
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
)
+ parser.add_argument(
+ "--request-id-prefix",
+ type=str,
+ required=False,
+ default="benchmark-serving",
+ help="Specify the prefix of request id.",
+ )
# group for dataset specific arguments
custom_group = parser.add_argument_group("custom dataset options")
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index c51b579686529..c7f290e1eb88e 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -597,8 +597,8 @@ def validate_args(args):
# https://github.com/vllm-project/vllm/issues/16222
if args.data_parallel_size > 1:
raise ValueError(
- "Data parallel is not supported in offline benchmark, \
- please use benchmark serving instead"
+ "Data parallel is not supported in offline benchmark, "
+ "please use benchmark serving instead"
)
diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
index 1d4e730f99ae9..a6b42406b5cb0 100644
--- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
+++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
@@ -80,6 +80,11 @@ def bench_run(
a, score, topk, renormalize=False
)
+ ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+ ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
+ c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
+ c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+
def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
@@ -111,6 +116,10 @@ def bench_run(
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
+ ab_strides1: torch.Tensor,
+ ab_strides2: torch.Tensor,
+ c_strides1: torch.Tensor,
+ c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
@@ -125,6 +134,10 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
per_act_token,
a1_scale=None,
)
@@ -136,6 +149,10 @@ def bench_run(
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
+ ab_strides1: torch.Tensor,
+ ab_strides2: torch.Tensor,
+ c_strides1: torch.Tensor,
+ c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
):
@@ -150,6 +167,10 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
per_act_token,
a1_scale=None,
)
@@ -194,6 +215,10 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
topk_weights,
topk_ids,
)
@@ -231,6 +256,10 @@ def bench_run(
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
+ "ab_strides1": ab_strides1,
+ "ab_strides2": ab_strides2,
+ "c_strides1": c_strides1,
+ "c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@@ -289,6 +318,10 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
topk_weights,
topk_ids,
per_act_token,
@@ -297,7 +330,7 @@ def bench_run(
results.append(
benchmark.Timer(
- stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
+ stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py
index 975d10f2e92ec..a9c4d30d9b189 100644
--- a/benchmarks/kernels/benchmark_machete.py
+++ b/benchmarks/kernels/benchmark_machete.py
@@ -253,28 +253,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
else:
assert bt.a.dtype == torch.int8
assert bt.wtype == scalar_types.uint4b8
-
- if bt.w_ch_s is not None:
- s_ch = bt.w_ch_s.to(torch.float32)
- else:
- s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device)
-
- if bt.w_tok_s is not None:
- s_tok = bt.w_tok_s.to(torch.float32)
- else:
- s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device)
-
- fn = lambda: ops.marlin_qqq_gemm(
- a=bt.a,
- b_q_weight=w_q,
- s_group=w_s,
- s_tok=s_tok,
- s_ch=s_ch,
- workspace=workspace.scratch,
- size_m=bt.a.shape[0],
- size_n=bt.w_ref.shape[1],
- size_k=bt.w_ref.shape[0],
- )
+ raise NotImplementedError("QQQ is not supported anymore")
return fn
diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index b4a03665ef10f..752c2d0082167 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -430,7 +430,6 @@ class BenchmarkWorker:
hidden_size,
topk,
dtype_str,
- is_marlin=False,
)
else:
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
new file mode 100644
index 0000000000000..0650cbf3cc18e
--- /dev/null
+++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import time
+
+import torch
+
+from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
+ silu_mul_fp8_quant_deep_gemm,
+)
+from vllm.platforms import current_platform
+
+
+def benchmark(E, T, H, G=128, runs=50):
+ current_platform.seed_everything(42)
+ y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
+ tokens_per_expert = torch.randint(
+ T // 2, T, size=(E,), dtype=torch.int32, device="cuda"
+ )
+
+ # Warmup
+ for _ in range(10):
+ silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
+ torch.cuda.synchronize()
+
+ # Benchmark
+ torch.cuda.synchronize()
+ start = time.perf_counter()
+ for _ in range(runs):
+ silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
+ torch.cuda.synchronize()
+
+ avg_time = (time.perf_counter() - start) / runs * 1000
+
+ # Calculate actual work done (only count valid tokens)
+ actual_tokens = tokens_per_expert.sum().item()
+ actual_elements = actual_tokens * H
+
+ # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops
+ ops_per_element = 8
+ total_ops = actual_elements * ops_per_element
+ gflops = total_ops / (avg_time / 1000) / 1e9
+
+ # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes)
+ input_bytes = actual_tokens * 2 * H * 2 # 2*H bfloat16 inputs
+ output_bytes = actual_tokens * H * 1 # H fp8 outputs
+ scale_bytes = actual_tokens * (H // G) * 4 # scales in float32
+ total_bytes = input_bytes + output_bytes + scale_bytes
+ memory_bw = total_bytes / (avg_time / 1000) / 1e9
+
+ return avg_time, gflops, memory_bw
+
+
+configs = [
+ (8, 32, 1024),
+ (16, 64, 2048),
+ (32, 128, 4096),
+ # DeepSeekV3 Configs
+ (256, 16, 7168),
+ (256, 32, 7168),
+ (256, 64, 7168),
+ (256, 128, 7168),
+ (256, 256, 7168),
+ (256, 512, 7168),
+ (256, 1024, 7168),
+]
+
+print(f"GPU: {torch.cuda.get_device_name()}")
+print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
+print("-" * 50)
+
+for E, T, H in configs:
+ try:
+ time_ms, gflops, gbps = benchmark(E, T, H)
+ print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
+ except Exception:
+ print(f"E={E:3d},T={T:4d},H={H:4d} FAILED")
diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py
index 77136edca45b5..72b54b40a2d1e 100644
--- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py
+++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py
@@ -3,16 +3,14 @@
import csv
import os
-import random
from datetime import datetime
+from typing import Optional
import flashinfer
import torch
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
-
-# KV Cache Layout for TRT-LLM
-# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
+FP8_DTYPE = torch.float8_e4m3fn
def to_float8(x, dtype=torch.float8_e4m3fn):
@@ -26,65 +24,107 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@torch.no_grad()
def benchmark_decode(
- num_seqs,
- max_seq_len,
- page_size=16,
- dtype=torch.bfloat16,
- kv_layout="HND",
- num_kv_heads=8,
- kv_cache_dtype="auto",
- head_dim=128,
- warmup=10,
- trials=20,
+ dtype: torch.dtype,
+ quant_dtypes: tuple[
+ Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
+ ],
+ batch_size: int,
+ max_seq_len: int,
+ num_heads: tuple[int, int] = (64, 8),
+ head_size: int = 128,
+ kv_layout: str = "HND",
+ block_size: int = 16,
+ warmup: int = 10,
+ trials: int = 20,
):
torch.set_default_device("cuda")
- device = "cuda"
torch.manual_seed(0)
- HEAD_GRP_SIZE = 8
- MAX_SEQ_LEN = max_seq_len
+ q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
+ q_quant_dtype = q_quant_dtype or dtype
+ kv_quant_dtype = kv_quant_dtype or dtype
+ o_quant_dtype = o_quant_dtype or dtype
+
+ num_qo_heads, num_kv_heads = num_heads
+ assert num_qo_heads % num_kv_heads == 0
+
+ sm_scale = float(1.0 / (head_size**0.5))
# large number to reduce kv_cache reuse
- NUM_BLOCKS = int(256000 / page_size)
+ NUM_BLOCKS = int(256000 / block_size)
- workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)
+ kv_cache_shape = None
+ if kv_layout == "NHD":
+ kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
+ elif kv_layout == "HND":
+ kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
+ else:
+ raise ValueError(f"Invalid kv_layout: {kv_layout}")
- # For decode, batch_size is num_decode_token
- num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
- sm_scale = float(1.0 / (head_dim**0.5))
- q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
- kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
+ query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
+ if q_quant_dtype == FP8_DTYPE:
+ query, q_scale = to_float8(query)
+ ref_query = query.to(dtype) * q_scale
+ else:
+ q_scale = 1.0
+ ref_query = query
- max_kv_len = max(kv_lens)
- kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
- max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size
+ kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
+ kv_lens[-1] = max_seq_len
+ seq_lens = kv_lens
+ max_seq_len = torch.max(seq_lens).item()
+
+ kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
+ if kv_quant_dtype == FP8_DTYPE:
+ kv_cache, kv_scale = to_float8(kv_cache)
+ ref_kv_cache = kv_cache.to(dtype) * kv_scale
+ else:
+ kv_scale = 1.0
+ ref_kv_cache = kv_cache
+ k_scale = v_scale = kv_scale
+
+ max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(
- 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
+ 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
)
+ kv_indptr = [0]
+ kv_indices = []
+ kv_last_page_lens = []
+ for i in range(batch_size):
+ seq_len = seq_lens[i]
+ assert seq_len > 0
+ num_blocks = (seq_len + block_size - 1) // block_size
+ kv_indices.extend(block_tables[i, :num_blocks])
+ kv_indptr.append(kv_indptr[-1] + num_blocks)
+ kv_last_page_len = seq_len % block_size
+ if kv_last_page_len == 0:
+ kv_last_page_len = block_size
+ kv_last_page_lens.append(kv_last_page_len)
- kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
- kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
- k_scale = v_scale = 1.0
+ kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
+ kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
+ kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
+ workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)
- if kv_cache_dtype.startswith("fp8"):
- kv_cache, _ = to_float8(kv_cache)
-
- output_trtllm = torch.empty(q.shape, dtype=dtype)
-
- # Benchmark TRT decode
- def trt_decode():
- return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
- q,
- kv_cache,
- workspace_buffer,
- block_tables,
- kv_lens_tensor,
- max_kv_len,
- bmm1_scale=k_scale * sm_scale,
- bmm2_scale=v_scale,
- out=output_trtllm,
- )
+ wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+ workspace_buffer,
+ kv_layout,
+ use_tensor_cores=True,
+ )
+ wrapper.plan(
+ kv_indptr,
+ kv_indices,
+ kv_last_page_lens,
+ num_qo_heads,
+ num_kv_heads,
+ head_size,
+ block_size,
+ "NONE",
+ sm_scale=sm_scale,
+ q_data_type=dtype,
+ kv_data_type=dtype,
+ )
def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize()
@@ -101,74 +141,51 @@ def benchmark_decode(
times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times))
- # TRT Decode
- trt_mean, trt_std = time_fn(trt_decode)
-
- kv_indptr = [0]
- kv_indices = []
- kv_last_page_lens = []
- for i in range(num_seqs):
- seq_len = kv_lens[i]
- assert seq_len > 0
- num_blocks = (seq_len + page_size - 1) // page_size
- kv_indices.extend(block_tables[i, :num_blocks])
- kv_indptr.append(kv_indptr[-1] + num_blocks)
- kv_last_page_len = seq_len % page_size
- if kv_last_page_len == 0:
- kv_last_page_len = page_size
- kv_last_page_lens.append(kv_last_page_len)
-
- kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
- kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
- kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
-
- output_baseline = torch.empty(q.shape, dtype=dtype)
-
- wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
- workspace_buffer,
- kv_layout,
- use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
- )
-
- wrapper.plan(
- kv_indptr,
- kv_indices,
- kv_last_page_lens,
- num_qo_heads,
- num_kv_heads,
- head_dim,
- page_size,
- "NONE",
- q_data_type=dtype,
- kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
- )
+ o_scale = 1.0
+ output_baseline = torch.empty(ref_query.shape, dtype=dtype)
+ output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
def baseline_decode():
- return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline)
+ return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
+
+ def trtllm_decode():
+ return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
+ query=query,
+ kv_cache=kv_cache,
+ workspace_buffer=workspace_buffer,
+ block_tables=block_tables,
+ seq_lens=seq_lens,
+ max_seq_len=max_seq_len,
+ bmm1_scale=q_scale * k_scale * sm_scale,
+ bmm2_scale=v_scale / o_scale,
+ out=output_trtllm,
+ )
baseline_mean, baseline_std = time_fn(baseline_decode)
+ trtllm_mean, trtllm_std = time_fn(trtllm_decode)
# Calculate percentage speedup (positive means TRT is faster)
- speedup_percent = (baseline_mean - trt_mean) / baseline_mean
+ speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean
print(
- f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
+ f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}"
f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
)
# Return results for CSV writing
return {
- "num_seqs": num_seqs,
- "trt_mean": trt_mean,
- "trt_std": trt_std.item(),
+ "batch_size": batch_size,
+ "trtllm_mean": trtllm_mean,
+ "trtllm_std": trtllm_std.item(),
"baseline_mean": baseline_mean,
"baseline_std": baseline_std.item(),
"speedup_percent": speedup_percent,
- "q_dtype": str(dtype),
- "kv_cache_dtype": kv_cache_dtype,
- "page_size": page_size,
+ "q_dtype": str(q_quant_dtype),
+ "kv_cache_dtype": str(kv_quant_dtype),
+ "output_dtype": str(o_quant_dtype),
+ "block_size": block_size,
"num_kv_heads": num_kv_heads,
- "head_dim": head_dim,
+ "head_size": head_size,
"max_seq_len": max_seq_len,
}
@@ -180,17 +197,18 @@ def write_results_to_csv(results, filename=None):
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
fieldnames = [
- "num_seqs",
- "trt_mean",
- "trt_std",
+ "batch_size",
+ "trtllm_mean",
+ "trtllm_std",
"baseline_mean",
"baseline_std",
"speedup_percent",
"q_dtype",
"kv_cache_dtype",
- "page_size",
+ "output_dtype",
+ "block_size",
"num_kv_heads",
- "head_dim",
+ "head_size",
"max_seq_len",
]
@@ -209,45 +227,42 @@ def write_results_to_csv(results, filename=None):
if __name__ == "__main__":
- num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
+ batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
all_results = []
- print(
- "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
- "output_dtype: bfloat16"
- )
- print(
- "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
- "baseline_std\tspeedup_percent"
- )
- for max_seq_len in max_seq_lens:
- for bs in num_seqs:
- result = benchmark_decode(
- bs,
- max_seq_len,
- dtype=torch.bfloat16,
- kv_cache_dtype="auto",
- )
- all_results.append(result)
+ dtype = torch.bfloat16
+ quant_dtypes = [
+ # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
+ (None, None, None),
+ (None, FP8_DTYPE, None),
+ (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
+ ]
- print(
- "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, "
- "output_dtype: bfloat16"
- )
- print(
- "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
- "baseline_std\tspeedup_percent"
- )
- for max_seq_len in max_seq_lens:
- for bs in num_seqs:
- result = benchmark_decode(
- bs,
- max_seq_len,
- dtype=torch.bfloat16,
- kv_cache_dtype="fp8",
- )
- all_results.append(result)
+ for quant_dtype in quant_dtypes:
+ q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
+ q_quant_dtype = q_quant_dtype or dtype
+ kv_quant_dtype = kv_quant_dtype or dtype
+ o_quant_dtype = o_quant_dtype or dtype
+
+ print(
+ f"Running benchmark for q_dtype = {q_quant_dtype}, "
+ f"kv_cache_dtype: {kv_quant_dtype}, "
+ f"output_dtype: {o_quant_dtype}"
+ )
+ print(
+ "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
+ "baseline_std\tspeedup_percent"
+ )
+ for max_seq_len in max_seq_lens:
+ for bs in batch_sizes:
+ result = benchmark_decode(
+ dtype=dtype,
+ quant_dtypes=quant_dtype,
+ batch_size=bs,
+ max_seq_len=max_seq_len,
+ )
+ all_results.append(result)
# Write all results to CSV
write_results_to_csv(all_results)
diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py
index 67bd9aebbcca9..49810e20c7d82 100644
--- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py
+++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py
@@ -3,16 +3,14 @@
import csv
import os
-import random
from datetime import datetime
+from typing import Optional
import flashinfer
import torch
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
-
-# KV Cache Layout for TRT-LLM
-# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
+FP8_DTYPE = torch.float8_e4m3fn
def to_float8(x, dtype=torch.float8_e4m3fn):
@@ -26,84 +24,99 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@torch.no_grad()
def benchmark_prefill(
- num_seqs,
- max_seq_len,
- page_size=16,
- dtype=torch.bfloat16,
- kv_layout="HND",
- num_kv_heads=8,
- kv_cache_dtype="auto",
- head_dim=128,
- warmup=10,
- trials=20,
+ dtype: torch.dtype,
+ quant_dtypes: tuple[
+ Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
+ ],
+ batch_size: int,
+ max_seq_len: int,
+ num_heads: tuple[int, int] = (64, 8),
+ head_size: int = 128,
+ kv_layout: str = "HND",
+ block_size: int = 16,
+ warmup: int = 10,
+ trials: int = 20,
):
torch.set_default_device("cuda")
torch.manual_seed(0)
- HEAD_GRP_SIZE = 8
- MAX_SEQ_LEN = max_seq_len
+ q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
+ q_quant_dtype = q_quant_dtype or dtype
+ kv_quant_dtype = kv_quant_dtype or dtype
+ o_quant_dtype = o_quant_dtype or dtype
+
+ max_q_len = max_kv_len = max_seq_len
+
+ num_qo_heads, num_kv_heads = num_heads
+ assert num_qo_heads % num_kv_heads == 0
+
+ sm_scale = float(1.0 / (head_size**0.5))
# large number to reduce kv_cache reuse
- NUM_BLOCKS = int(256000 / page_size)
+ NUM_BLOCKS = int(256000 / block_size)
- workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8)
+ kv_cache_shape = None
+ if kv_layout == "NHD":
+ kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
+ elif kv_layout == "HND":
+ kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
+ else:
+ raise ValueError(f"Invalid kv_layout: {kv_layout}")
- num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
- sm_scale = float(1.0 / (head_dim**0.5))
-
- q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
- q_lens[-1] = MAX_SEQ_LEN
- max_q_len = max(q_lens)
+ q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32)
+ q_lens[-1] = max_q_len
q_indptr = torch.cat(
[
torch.tensor([0], dtype=torch.int32),
- torch.cumsum(
- torch.tensor(q_lens, dtype=torch.int32), dim=0, dtype=torch.int32
- ),
+ torch.cumsum(q_lens, dim=0, dtype=torch.int32),
]
)
- q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype)
- kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)]
- kv_lens[-1] = MAX_SEQ_LEN
+ query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
+ if q_quant_dtype == FP8_DTYPE:
+ query, q_scale = to_float8(query)
+ ref_query = query.to(dtype) * q_scale
+ else:
+ q_scale = 1.0
+ ref_query = query
- seq_lens = [q_len + kv_len for q_len, kv_len in zip(q_lens, kv_lens)]
- max_seq_len = max(seq_lens)
- seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)
+ kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
+ kv_lens[-1] = max_kv_len
- max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size
+ seq_lens = kv_lens + q_lens
+ max_seq_len = torch.max(seq_lens).item()
+
+ kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
+ if kv_quant_dtype == FP8_DTYPE:
+ kv_cache, kv_scale = to_float8(kv_cache)
+ ref_kv_cache = kv_cache.to(dtype) * kv_scale
+ else:
+ kv_scale = 1.0
+ ref_kv_cache = kv_cache
+ k_scale = v_scale = kv_scale
+
+ max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(
- 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
+ 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
)
-
- kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
- kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype)
- k_scale = v_scale = 1.0
-
- if kv_cache_dtype.startswith("fp8"):
- kv_cache, _ = to_float8(kv_cache)
-
- output_trtllm = torch.empty(q.shape, dtype=dtype)
-
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
- for i in range(num_seqs):
+ for i in range(batch_size):
seq_len = seq_lens[i]
assert seq_len > 0
- num_blocks = (seq_len + page_size - 1) // page_size
+ num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
- kv_last_page_len = seq_len % page_size
+ kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
- kv_last_page_len = page_size
+ kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
-
- output_baseline = torch.empty(q.shape, dtype=dtype)
+ workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout
@@ -115,12 +128,12 @@ def benchmark_prefill(
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
- head_dim,
- page_size,
+ head_size,
+ block_size,
causal=True,
sm_scale=sm_scale,
q_data_type=dtype,
- kv_data_type=kv_cache.dtype,
+ kv_data_type=dtype,
)
def time_fn(fn, warmup=10, trials=20):
@@ -138,52 +151,55 @@ def benchmark_prefill(
times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times))
- def baseline_prefill():
- return wrapper.run(
- q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline
- )
+ o_scale = 1.0
+ output_baseline = torch.empty(ref_query.shape, dtype=dtype)
+ output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
- def trt_prefill():
+ def baseline_prefill():
+ return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
+
+ def trtllm_prefill():
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
- query=q,
+ query=query,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
- seq_lens=seq_lens_tensor,
+ seq_lens=seq_lens,
max_q_len=max_q_len,
max_kv_len=max_seq_len,
- bmm1_scale=k_scale * sm_scale,
- bmm2_scale=v_scale,
- batch_size=num_seqs,
+ bmm1_scale=q_scale * k_scale * sm_scale,
+ bmm2_scale=v_scale / o_scale,
+ batch_size=batch_size,
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
out=output_trtllm,
)
- trt_mean, trt_std = time_fn(trt_prefill)
baseline_mean, baseline_std = time_fn(baseline_prefill)
+ trtllm_mean, trtllm_std = time_fn(trtllm_prefill)
# Calculate percentage speedup (positive means TRT is faster)
- speedup_percent = (baseline_mean - trt_mean) / baseline_mean
+ speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean
print(
- f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}"
- f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}"
+ f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}"
+ f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}"
)
# Return results for CSV writing
return {
- "num_seqs": num_seqs,
- "trt_mean": trt_mean,
- "trt_std": trt_std.item(),
+ "batch_size": batch_size,
+ "trtllm_mean": trtllm_mean,
+ "trtllm_std": trtllm_std.item(),
"baseline_mean": baseline_mean,
"baseline_std": baseline_std.item(),
"speedup_percent": speedup_percent,
- "q_dtype": str(dtype),
- "kv_cache_dtype": kv_cache_dtype,
- "page_size": page_size,
+ "q_dtype": str(q_quant_dtype),
+ "kv_cache_dtype": str(kv_quant_dtype),
+ "output_dtype": str(o_quant_dtype),
+ "block_size": block_size,
"num_kv_heads": num_kv_heads,
- "head_dim": head_dim,
+ "head_size": head_size,
"max_seq_len": max_seq_len,
}
@@ -195,17 +211,18 @@ def write_results_to_csv(results, filename=None):
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
fieldnames = [
- "num_seqs",
- "trt_mean",
- "trt_std",
+ "batch_size",
+ "trtllm_mean",
+ "trtllm_std",
"baseline_mean",
"baseline_std",
"speedup_percent",
"q_dtype",
"kv_cache_dtype",
- "page_size",
+ "output_dtype",
+ "block_size",
"num_kv_heads",
- "head_dim",
+ "head_size",
"max_seq_len",
]
@@ -224,27 +241,41 @@ def write_results_to_csv(results, filename=None):
if __name__ == "__main__":
- num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
+ batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
all_results = []
- print(
- "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
- "output_dtype: bfloat16"
- )
- print(
- "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
- "baseline_std\tspeedup_percent"
- )
- for max_seq_len in max_seq_lens:
- for bs in num_seqs:
- result = benchmark_prefill(
- bs,
- max_seq_len,
- dtype=torch.bfloat16,
- kv_cache_dtype="auto",
- )
- all_results.append(result)
+ dtype = torch.bfloat16
+ quant_dtypes = [
+ # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
+ (None, None, None),
+ (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
+ ]
+
+ for quant_dtype in quant_dtypes:
+ q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
+ q_quant_dtype = q_quant_dtype or dtype
+ kv_quant_dtype = kv_quant_dtype or dtype
+ o_quant_dtype = o_quant_dtype or dtype
+
+ print(
+ f"Running benchmark for q_dtype = {q_quant_dtype}, "
+ f"kv_cache_dtype: {kv_quant_dtype}, "
+ f"output_dtype: {o_quant_dtype}"
+ )
+ print(
+ "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
+ "baseline_std\tspeedup_percent"
+ )
+ for max_seq_len in max_seq_lens:
+ for bs in batch_sizes:
+ result = benchmark_prefill(
+ dtype=dtype,
+ quant_dtypes=quant_dtype,
+ batch_size=bs,
+ max_seq_len=max_seq_len,
+ )
+ all_results.append(result)
# Write all results to CSV
write_results_to_csv(all_results)
diff --git a/benchmarks/multi_turn/README.md b/benchmarks/multi_turn/README.md
index ae0866ae60751..7adf97bcf5622 100644
--- a/benchmarks/multi_turn/README.md
+++ b/benchmarks/multi_turn/README.md
@@ -5,11 +5,13 @@ The requirements (pip) for `benchmark_serving_multi_turn.py` can be found in `re
First start serving your model
```bash
-export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
+export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
-vllm serve $MODEL_NAME --disable-log-requests
+vllm serve $MODEL_PATH --served-model-name Llama --disable-log-requests
```
+The variable `MODEL_PATH` should be a path to the model files (e.g. downloaded from huggingface).
+
## Synthetic Multi-Turn Conversations
Download the following text file (used for generation of synthetic conversations)
@@ -26,10 +28,10 @@ But you may use other text files if you prefer (using this specific file is not
Then run the benchmarking script
```bash
-export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
+export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
-python benchmark_serving_multi_turn.py --model $MODEL_NAME --input-file generate_multi_turn.json \
---num-clients 2 --max-active-conversations 6
+python benchmark_serving_multi_turn.py --model $MODEL_PATH --served-model-name Llama \
+--input-file generate_multi_turn.json --num-clients 2 --max-active-conversations 6
```
You can edit the file `generate_multi_turn.json` to change the conversation parameters (number of turns, etc.).
diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py
index 53c3207491d18..d23b7b6e4571d 100644
--- a/benchmarks/multi_turn/benchmark_serving_multi_turn.py
+++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py
@@ -825,9 +825,11 @@ def get_client_config(
# Arguments for API requests
chat_url = f"{args.url}/v1/chat/completions"
+ model_name = args.served_model_name if args.served_model_name else args.model
+
req_args = RequestArgs(
chat_url=chat_url,
- model=args.model,
+ model=model_name,
stream=not args.no_stream,
limit_min_tokens=args.limit_min_tokens,
limit_max_tokens=args.limit_max_tokens,
@@ -1247,9 +1249,19 @@ async def main() -> None:
default=0,
help="Seed for random number generators (default: 0)",
)
+
parser.add_argument(
"-m", "--model", type=str, required=True, help="Path of the LLM model"
)
+ parser.add_argument(
+ "--served-model-name",
+ type=str,
+ default=None,
+ help="The model name used in the API. "
+ "If not specified, the model name will be the "
+ "same as the ``--model`` argument. ",
+ )
+
parser.add_argument(
"-u",
"--url",
diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index e0da46e2accaa..cc38cd41a5b24 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -182,17 +182,17 @@ endif()
#
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
# Flag to enable ACL kernels for AARCH64 platforms
-if ( VLLM_BUILD_ACL STREQUAL "ON")
+if (VLLM_BUILD_ACL STREQUAL "ON")
set(USE_ACL ON)
else()
set(USE_ACL OFF)
endif()
-if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
+if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
FetchContent_Declare(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
- GIT_TAG v3.8.1
+ GIT_TAG v3.9
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
@@ -204,7 +204,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
endif()
set(ONEDNN_AARCH64_USE_ACL "ON")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
- endif()
+ endif()
set(ONEDNN_LIBRARY_TYPE "STATIC")
set(ONEDNN_BUILD_DOC "OFF")
@@ -217,38 +217,23 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
+ set(ONEDNN_VERBOSE "OFF")
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
FetchContent_MakeAvailable(oneDNN)
-
- list(APPEND LIBS dnnl)
-elseif(POWER10_FOUND)
- FetchContent_Declare(
- oneDNN
- GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
- GIT_TAG v3.7.2
- GIT_PROGRESS TRUE
- GIT_SHALLOW TRUE
+ add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp")
+ target_include_directories(
+ dnnl_ext
+ PUBLIC ${oneDNN_SOURCE_DIR}/include
+ PUBLIC ${oneDNN_BINARY_DIR}/include
+ PRIVATE ${oneDNN_SOURCE_DIR}/src
)
-
- set(ONEDNN_LIBRARY_TYPE "STATIC")
- set(ONEDNN_BUILD_DOC "OFF")
- set(ONEDNN_BUILD_EXAMPLES "OFF")
- set(ONEDNN_BUILD_TESTS "OFF")
- set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
- set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
- set(ONEDNN_BUILD_GRAPH "OFF")
- set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
- set(ONEDNN_ENABLE_ITT_TASKS "OFF")
- set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
- set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
- set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
-
- set(DNNL_CPU_RUNTIME "OMP")
-
- FetchContent_MakeAvailable(oneDNN)
-
- list(APPEND LIBS dnnl)
+ target_link_libraries(dnnl_ext dnnl)
+ target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC)
+ list(APPEND LIBS dnnl_ext)
+ set(USE_ONEDNN ON)
+else()
+ set(USE_ONEDNN OFF)
endif()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
@@ -275,7 +260,6 @@ set(VLLM_EXT_SRC
if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC
- "csrc/cpu/quant.cpp"
"csrc/cpu/shm.cpp"
${VLLM_EXT_SRC})
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
@@ -289,14 +273,11 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
${VLLM_EXT_SRC})
add_compile_definitions(-DCPU_CAPABILITY_AVX512)
endif()
-elseif(POWER10_FOUND)
- set(VLLM_EXT_SRC
- "csrc/cpu/quant.cpp"
- ${VLLM_EXT_SRC})
endif()
-if (ASIMD_FOUND)
+
+if(USE_ONEDNN)
set(VLLM_EXT_SRC
- "csrc/cpu/quant.cpp"
+ "csrc/cpu/dnnl_kernels.cpp"
${VLLM_EXT_SRC})
endif()
diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake
index ee6768bce26ca..02224cfe3ee81 100644
--- a/cmake/external_projects/flashmla.cmake
+++ b/cmake/external_projects/flashmla.cmake
@@ -19,7 +19,7 @@ else()
FetchContent_Declare(
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
- GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1
+ GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
@@ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
set(FlashMLA_SOURCES
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
- ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
+ ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
- ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu)
+ ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
+ ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)
set(FlashMLA_INCLUDES
${flashmla_SOURCE_DIR}/csrc/cutlass/include
- ${flashmla_SOURCE_DIR}/csrc/include)
+ ${flashmla_SOURCE_DIR}/csrc)
set_gencode_flags_for_srcs(
SRCS "${FlashMLA_SOURCES}"
diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu
index e0e95d06290df..6dd6f269f3dc9 100644
--- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu
+++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu
@@ -167,7 +167,7 @@ typename T::Fmha::Arguments args_from_options(
// TODO(trevor-m): Change split_kv back to -1 when
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
// perform worse with larger context length and smaller batch sizes.
- num_kv_splits, // split_kv
+ static_cast<int>(num_kv_splits), // split_kv
nullptr, // is_var_split_kv
};
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
@@ -264,7 +264,7 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba
// Assumes device 0 when getting sm_count.
arguments.hw_info.sm_count =
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
- arguments.split_kv = num_kv_splits;
+ arguments.split_kv = static_cast<int>(num_kv_splits);
MlaSm100Type::Fmha::set_split_kv(arguments);
return MlaSm100Type::Fmha::get_workspace_size(arguments);
diff --git a/csrc/cache.h b/csrc/cache.h
index 0970b704be3ab..fb0c353b96137 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -40,9 +40,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
-void gather_cache(
+void gather_and_maybe_dequant_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
- int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
\ No newline at end of file
+ int64_t batch_size, const std::string& kv_cache_dtype,
+ torch::Tensor const& scale,
+ std::optional<torch::Tensor> seq_starts = std::nullopt);
\ No newline at end of file
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 131dcb15cd7e9..b3a985c2d5bbb 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -624,9 +624,9 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
namespace vllm {
// grid is launched with dimensions (batch, num_splits)
-template <typename scalar_t>
-__global__ void gather_cache(
- const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
+__global__ void gather_and_maybe_dequant_cache(
+ const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// ENTRIES...]
scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
@@ -634,6 +634,7 @@ __global__ void gather_cache(
const int32_t block_size, const int32_t entry_size,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
+ const float* __restrict__ scale,
const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per
// batch
@@ -675,10 +676,16 @@ __global__ void gather_cache(
if (partial_block_size) full_blocks_end -= 1;
}
- auto copy_entry = [&](const scalar_t* __restrict__ _src,
+ auto copy_entry = [&](const cache_t* __restrict__ _src,
scalar_t* __restrict__ _dst) {
- for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
- _dst[i] = _src[i];
+ for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
+ if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+ _dst[i] = static_cast<scalar_t>(_src[i]);
+ } else {
+ _dst[i] =
+ fp8::scaled_convert<scalar_t, cache_t, kv_dt>(_src[i], *scale);
+ }
+ }
};
for (int pid = split_start; pid < full_blocks_end; ++pid) {
@@ -705,25 +712,31 @@ __global__ void gather_cache(
} // namespace vllm
// Macro to dispatch the kernel based on the data type.
-#define CALL_GATHER_CACHE(CPY_DTYPE) \
- vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>( \
- reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()), \
- reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()), \
- block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
- block_size, entry_size, block_table_stride, cache_block_stride, \
- cache_entry_stride, dst_entry_stride, seq_starts_ptr);
+// SCALAR_T is the data type of the destination tensor.
+// CACHE_T is the stored data type of kv-cache.
+// KV_DTYPE is the real data type of kv-cache.
+#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
+ vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE> \
+ <<<grid, block, 0, stream>>>( \
+ reinterpret_cast<CACHE_T*>(src_cache.data_ptr()), \
+ reinterpret_cast<SCALAR_T*>(dst.data_ptr()), \
+ block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
+ block_size, entry_size, block_table_stride, cache_block_stride, \
+ cache_entry_stride, dst_entry_stride, \
+ reinterpret_cast<const float*>(scale.data_ptr()), seq_starts_ptr);
// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
// - Optionally, seq_starts (if provided) offsets the starting block index by
// (seq_starts[bid] / page_size)
-void gather_cache(
+void gather_and_maybe_dequant_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
- int64_t batch_size,
+ int64_t batch_size, const std::string& kv_cache_dtype,
+ torch::Tensor const& scale,
std::optional<torch::Tensor> seq_starts = std::nullopt) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -761,20 +774,8 @@ void gather_cache(
dim3 grid(batch_size, num_splits);
dim3 block(1024);
- TORCH_CHECK(src_cache.dtype() == dst.dtype(),
- "src_cache and dst must have the same dtype");
-
- const int dtype_bits = src_cache.element_size() * 8;
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
- if (dtype_bits == 32) {
- CALL_GATHER_CACHE(uint32_t);
- } else if (dtype_bits == 16) {
- CALL_GATHER_CACHE(uint16_t);
- } else if (dtype_bits == 8) {
- CALL_GATHER_CACHE(uint8_t);
- } else {
- TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
- }
+ DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
}
diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp
index 3952c43cbc727..982f7c07a13bd 100644
--- a/csrc/cpu/cpu_types_x86.hpp
+++ b/csrc/cpu/cpu_types_x86.hpp
@@ -89,7 +89,7 @@ struct FP16Vec16 : public Vec {
explicit FP16Vec16(const FP32Vec16&);
- void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
+ void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
void save(void* ptr, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
@@ -126,7 +126,7 @@ struct BF16Vec16 : public Vec {
explicit BF16Vec16(const FP32Vec16&);
- void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
+ void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
void save(void* ptr, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
@@ -180,8 +180,8 @@ struct BF16Vec32 : public Vec {
(__m128i)vec8_data.reg, 1)) {}
void save(void* ptr) const {
- *reinterpret_cast<__m256i*>(ptr) = reg_low;
- *reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high;
+ _mm256_storeu_si256((__m256i*)ptr, reg_low);
+ _mm256_storeu_si256((__m256i*)ptr + 1, reg_high);
}
};
#endif
diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp
new file mode 100644
index 0000000000000..f3f00edb36068
--- /dev/null
+++ b/csrc/cpu/dnnl_helper.cpp
@@ -0,0 +1,346 @@
+#include
+#include
+
+#include "common/memory_desc.hpp"
+#include "common/memory.hpp"
+
+#include "dnnl_helper.h"
+
+static dnnl::engine& default_engine() {
+ static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
+ return engine;
+}
+
+static dnnl::stream& default_stream() {
+ static dnnl::stream stream(default_engine());
+ return stream;
+}
+
+void release_dnnl_matmul_handler(int64_t handler) {
+ DNNLMatMulPrimitiveHandler* ptr =
+ reinterpret_cast<DNNLMatMulPrimitiveHandler*>(handler);
+ delete ptr;
+}
+
+template <typename KT, typename VT>
+class DNNLPrimitiveCache {
+ public:
+ using cache_value_t = std::pair<KT, VT>;
+ using result_value_t = VT;
+ using container_t = std::list<cache_value_t>;
+ using value_iterator_t = typename container_t::iterator;
+ using map_t = std::unordered_map<KT, value_iterator_t>;
+ using creator_t = VT (*)();
+
+ public:
+ DNNLPrimitiveCache(size_t capacity)
+ : capacity_(capacity),
+ values_(),
+ key_to_value_(std::min(256lu, capacity)) {
+ assert(capacity > 0);
+ }
+
+ template <typename F>
+ result_value_t get_or_create(const KT& key, F&& creator) {
+ std::optional<value_iterator_t> value = get_value(key);
+ if (value.has_value()) {
+ return value.value()->second;
+ } else {
+ return add_value({key, creator()})->second;
+ }
+ }
+
+ size_t size() const { return values_.size(); }
+
+ private:
+ void dump_data() {
+ std::stringstream ss;
+ ss << "table_id: " << std::hex << reinterpret_cast(this) << std::dec
+ << "\n";
+ ss << "container: [";
+ for (auto&& iter : values_) {
+ ss << "(" << iter.first << ", " << std::hex
+ << reinterpret_cast(iter.second.get()) << "), " << std::dec;
+ }
+ ss << "]\n";
+
+ ss << "map: [";
+ for (auto&& iter : key_to_value_) {
+ ss << "(" << iter.first << ", " << iter.second->first << ", " << std::hex
+ << reinterpret_cast(iter.second->second.get()) << std::dec
+ << "), ";
+ }
+ ss << "]\n";
+ std::printf("%s\n", ss.str().c_str());
+ }
+
+ value_iterator_t add_value(cache_value_t&& new_value) {
+ if (size() == capacity_) {
+ cache_value_t& last_item = values_.back();
+ key_to_value_.erase(last_item.first);
+ values_.pop_back();
+ }
+
+ auto& added_value_ = values_.emplace_front(std::move(new_value));
+ key_to_value_.emplace(added_value_.first, values_.begin());
+ return values_.begin();
+ }
+
+ std::optional<value_iterator_t> get_value(const KT& key) {
+ if (key_to_value_.size() > 0 && key == values_.begin()->first) {
+ return values_.begin();
+ }
+
+ auto value_map_iterator = key_to_value_.find(key);
+ if (value_map_iterator != key_to_value_.end()) {
+ values_.splice(values_.begin(), values_, value_map_iterator->second);
+ return value_map_iterator->second;
+ } else {
+ return {};
+ }
+ }
+
+ private:
+ const size_t capacity_;
+ container_t values_;
+ map_t key_to_value_;
+};
+
+DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler(
+ const Args& args, dnnl::memory::data_type b_type)
+ : b_n_size_(args.b_n_size),
+ b_n_stride_(args.b_n_stride),
+ b_k_size_(args.b_k_size),
+ b_k_stride_(args.b_k_stride),
+ b_type_(b_type),
+ c_type_(args.c_type),
+ runtime_memory_ptrs_(8),
+ primitive_cache_size_(args.primitive_cache_size) {
+ assert(primitive_cache_size_ > 0);
+}
+
+void DNNLMatMulPrimitiveHandler::prepack_weight(
+ void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) {
+ dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
+ {b_k_stride_, b_n_stride_});
+ dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr);
+ dnnl::memory packed_weight(b_target_mem_desc, default_engine());
+ {
+ dnnl::reorder(original_weight, packed_weight)
+ .execute(default_stream(), original_weight, packed_weight);
+ default_stream().wait();
+ }
+ memory_cache_[DNNL_ARG_WEIGHTS] = packed_weight;
+ b_target_mem_desc_ = b_target_mem_desc;
+}
+
+void DNNLMatMulPrimitiveHandler::set_runtime_memory_ptr(
+ size_t index, dnnl_memory* memory_ptr) {
+ dnnl::impl::memory_storage_t* mem_storage_ptr = memory_ptr->memory_storage();
+ dnnl_memory_desc* mem_desc = const_cast(memory_ptr->md());
+ runtime_memory_ptrs_[index] = {mem_storage_ptr, mem_desc};
+}
+
+std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>
+DNNLMatMulPrimitiveHandler::get_runtime_memory_ptr(size_t index) {
+ return runtime_memory_ptrs_[index];
+}
+
+namespace std {
+template <>
+struct hash<W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey> {
+ size_t operator()(
+ const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
+ return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size) ^
+ hash<int>()(static_cast<int>(val.a_qs)) ^
+ hash<int>()(static_cast<int>(val.b_qs)) ^ hash<bool>()(val.use_azp) ^
+ hash<int>()(static_cast<int>(val.c_type));
+ }
+};
+
+template <>
+struct hash<W8A8MatMulPrimitiveHandler::MSizeCacheKey> {
+ size_t operator()(
+ const W8A8MatMulPrimitiveHandler::MSizeCacheKey& val) const {
+ return hash<dnnl_dim_t>()(val.a_m_size) ^ hash<bool>()(val.use_bias) ^
+ hash<int>()(static_cast<int>(val.bias_type));
+ }
+};
+} // namespace std
+
+bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
+ const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
+ return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size &&
+ l.a_qs == r.a_qs && l.b_qs == r.b_qs && l.use_azp == r.use_azp &&
+ l.c_type == r.c_type;
+}
+
+bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l,
+ const W8A8MatMulPrimitiveHandler::MSizeCacheKey& r) {
+ return l.use_bias == r.use_bias && l.a_m_size == r.a_m_size &&
+ l.bias_type == r.bias_type;
+}
+
+static std::shared_ptr<W8A8MatMulPrimitiveHandler::MSizeCache>
+get_w8a8_class_primitive_cache(
+ const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
+ int64_t cache_size) {
+ static W8A8MatMulPrimitiveHandler::ClassMatmulCache cache(128);
+ assert(cache_size > 0);
+ return cache.get_or_create(key, [&]() {
+ return std::make_shared<W8A8MatMulPrimitiveHandler::MSizeCache>(cache_size);
+ });
+}
+
+W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
+ : DNNLMatMulPrimitiveHandler(
+ static_cast<const DNNLMatMulPrimitiveHandler::Args&>(args),
+ dnnl::memory::data_type::s8),
+ use_azp_(args.use_a_zero_point),
+ a_qs_(args.a_quantization_strategy),
+ b_qs_(args.b_quantization_strategy),
+ m_size_cache_(nullptr) {
+ assert(a_qs_ != QuantizationStrategy::PER_OUTPUT_CHANNEL);
+ assert(b_qs_ != QuantizationStrategy::PER_TOKEN);
+ if (a_qs_ == QuantizationStrategy::PER_TOKEN) {
+ assert(!use_azp_);
+ };
+ prepack_weight(args.b_ptr,
+ create_primitive_desc(
+ MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
+ .use_bias = false,
+ .bias_type = dnnl::memory::data_type::undef},
+ true)
+ .weights_desc());
+ init_runtime_memory_cache(args);
+}
+
+void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
+ auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0);
+ auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1);
+ a_storage->set_data_handle((void*)args.a_ptr);
+ a_mem_desc->dims[0] = args.a_m_size;
+ c_storage->set_data_handle((void*)args.c_ptr);
+ c_mem_desc->dims[0] = args.a_m_size;
+
+ if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
+ auto&& [a_scale_storage, a_scale_mem_desc] = get_runtime_memory_ptr(2);
+ a_scale_storage->set_data_handle((void*)args.a_scales_ptr);
+ }
+ if (use_azp_) {
+ auto&& [a_zero_point_storage, a_zero_point_mem_desc] =
+ get_runtime_memory_ptr(3);
+ a_zero_point_storage->set_data_handle((void*)args.a_zero_points_ptr);
+ }
+
+ if (args.use_bias) {
+ auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(4);
+ bias_storage->set_data_handle((void*)args.bias_ptr);
+ }
+
+ dnnl::matmul matmul = get_matmul_cache(args);
+ matmul.execute(default_stream(), memory_cache_);
+ default_stream().wait();
+}
+
+dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
+ const MSizeCacheKey& key) {
+ if (m_size_cache_.get() == nullptr) {
+ ClassMatmulCacheKey key = {.b_n_size = b_n_size_,
+ .b_k_size = b_k_size_,
+ .a_qs = a_qs_,
+ .b_qs = b_qs_,
+ .use_azp = use_azp_,
+ .c_type = c_type_};
+ m_size_cache_ = get_w8a8_class_primitive_cache(key, primitive_cache_size_);
+ }
+
+ return m_size_cache_->get_or_create(key, [&]() {
+ dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
+ return dnnl::matmul(desc);
+ });
+}
+
+void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
+ memory_cache_[DNNL_ARG_SRC] = dnnl::memory({{1, b_k_size_},
+ dnnl::memory::data_type::s8,
+ dnnl::memory::format_tag::ab},
+ default_engine(), nullptr);
+ set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get());
+ memory_cache_[DNNL_ARG_DST] =
+ dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab},
+ default_engine(), nullptr);
+ set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());
+
+ // For PER_TOKEN, scales will be applied in outside epilogue
+ if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
+ memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = dnnl::memory(
+ {{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr);
+ set_runtime_memory_ptr(
+ 2, memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC].get());
+ if (use_azp_) {
+ memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = dnnl::memory(
+ {{1}, dnnl::memory::data_type::s32, {1}}, default_engine(), nullptr);
+ set_runtime_memory_ptr(
+ 3, memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC].get());
+ }
+ }
+
+ if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
+ memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
+ dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, default_engine(),
+ (void*)args.b_scales_ptr);
+ } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
+ memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
+ dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
+ default_engine(), (void*)args.b_scales_ptr);
+ }
+
+ memory_cache_[DNNL_ARG_BIAS] =
+ dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
+ default_engine(), nullptr);
+ set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get());
+}
+
+dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
+ const MSizeCacheKey& key, bool first_time) {
+ dnnl::memory::desc a_md({key.a_m_size, b_k_size_},
+ dnnl::memory::data_type::s8,
+ dnnl::memory::format_tag::ab);
+ dnnl::memory::desc b_md;
+ if (first_time) {
+ b_md =
+ dnnl::memory::desc({b_k_size_, b_n_size_}, dnnl::memory::data_type::s8,
+ dnnl::memory::format_tag::any);
+ } else {
+ b_md = b_target_mem_desc_;
+ }
+ dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
+ dnnl::memory::format_tag::ab);
+
+ dnnl::primitive_attr attr;
+ // For PER_TOKEN, scales will be applied in outside epilogue
+ if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
+ attr.set_scales_mask(DNNL_ARG_SRC, 0);
+ if (use_azp_) {
+ attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
+ }
+ }
+
+ if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
+ attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
+ } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
+ attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
+ }
+
+ if (key.use_bias) {
+ // For PER_TOKEN, bias will be applied in epilogue
+ assert(a_qs_ == QuantizationStrategy::PER_TENSOR);
+ dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
+ return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
+ c_md, attr);
+ } else {
+ return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
+ attr);
+ }
+}
diff --git a/csrc/cpu/dnnl_helper.h b/csrc/cpu/dnnl_helper.h
new file mode 100644
index 0000000000000..54ceefced9e98
--- /dev/null
+++ b/csrc/cpu/dnnl_helper.h
@@ -0,0 +1,169 @@
+#ifndef DNNL_HELPER_H
+#define DNNL_HELPER_H
+
+#include
+#include
+
+#include "oneapi/dnnl/dnnl.hpp"
+
+namespace c10 {
+struct BFloat16;
+struct Half;
+} // namespace c10
+
+namespace dnnl {
+namespace impl {
+struct memory_storage_t;
+struct matmul_pd_t;
+struct matmul_desc_t;
+} // namespace impl
+} // namespace dnnl
+struct dnnl_memory_desc;
+
+template <typename KT, typename VT>
+class DNNLPrimitiveCache;
+
+template <typename T>
+struct DNNLType {
+ static constexpr dnnl::memory::data_type type =
+ dnnl::memory::data_type::undef;
+};
+
+template <>
+struct DNNLType<int8_t> {
+ static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
+};
+
+template <>
+struct DNNLType<int32_t> {
+ static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
+};
+
+template <>
+struct DNNLType<float> {
+ static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
+};
+
+template <>
+struct DNNLType<c10::BFloat16> {
+ static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
+};
+
+template <>
+struct DNNLType<c10::Half> {
+ static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
+};
+
+template <typename T>
+constexpr inline dnnl::memory::data_type get_dnnl_type() {
+ return DNNLType<std::decay_t<T>>::type;
+}
+
+class DNNLMatMulPrimitiveHandler {
+ public:
+ virtual ~DNNLMatMulPrimitiveHandler() = default;
+
+ protected:
+ struct Args {
+ dnnl_dim_t b_n_size;
+ dnnl_dim_t b_n_stride;
+ dnnl_dim_t b_k_size;
+ dnnl_dim_t b_k_stride;
+ void* b_ptr;
+ dnnl::memory::data_type c_type;
+ size_t primitive_cache_size;
+ };
+
+ protected:
+ DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type);
+
+ void prepack_weight(void* original_b_ptr,
+ dnnl::memory::desc b_target_mem_desc);
+
+ void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr);
+
+ std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>
+ get_runtime_memory_ptr(size_t index);
+
+ protected:
+ const dnnl_dim_t b_n_size_;
+ const dnnl_dim_t b_n_stride_;
+ const dnnl_dim_t b_k_size_;
+ const dnnl_dim_t b_k_stride_;
+ dnnl::memory::data_type b_type_;
+ dnnl::memory::data_type c_type_;
+ std::unordered_map<int, dnnl::memory> memory_cache_;
+ std::vector<std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>>
+ runtime_memory_ptrs_;
+ dnnl::memory::desc b_target_mem_desc_;
+ int64_t primitive_cache_size_;
+};
+
+class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler {
+ public:
+ enum class QuantizationStrategy { PER_TOKEN, PER_TENSOR, PER_OUTPUT_CHANNEL };
+
+ struct Args : public DNNLMatMulPrimitiveHandler::Args {
+ bool use_a_zero_point;
+ QuantizationStrategy a_quantization_strategy;
+ QuantizationStrategy b_quantization_strategy;
+ float* b_scales_ptr;
+ };
+
+ struct ClassMatmulCacheKey {
+ dnnl_dim_t b_n_size;
+ dnnl_dim_t b_k_size;
+ QuantizationStrategy a_qs;
+ QuantizationStrategy b_qs;
+ bool use_azp;
+ dnnl::memory::data_type c_type;
+
+ friend bool operator==(const ClassMatmulCacheKey& l,
+ const ClassMatmulCacheKey& r);
+ };
+
+ struct MSizeCacheKey {
+ dnnl_dim_t a_m_size;
+ bool use_bias;
+ dnnl::memory::data_type bias_type;
+
+ friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r);
+ };
+
+ using MSizeCache = DNNLPrimitiveCache<MSizeCacheKey, dnnl::matmul>;
+ using ClassMatmulCache =
+ DNNLPrimitiveCache<ClassMatmulCacheKey, std::shared_ptr<MSizeCache>>;
+
+ struct ExecArgs : public MSizeCacheKey {
+ const int8_t* a_ptr;
+ const float* a_scales_ptr;
+ const int32_t* a_zero_points_ptr;
+ const void* bias_ptr;
+ void* c_ptr;
+ };
+
+ public:
+ W8A8MatMulPrimitiveHandler(const Args& args);
+
+ QuantizationStrategy get_input_scale_strategy() const { return a_qs_; }
+
+ bool get_input_use_zero_point() const { return use_azp_; }
+
+ void execute(ExecArgs& args);
+
+ private:
+ dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key,
+ bool first_time);
+
+ void init_runtime_memory_cache(const Args& args);
+
+ dnnl::matmul get_matmul_cache(const MSizeCacheKey& key);
+
+ private:
+ const bool use_azp_;
+ const QuantizationStrategy a_qs_;
+ const QuantizationStrategy b_qs_;
+ std::shared_ptr<MSizeCache> m_size_cache_;
+};
+
+#endif
diff --git a/csrc/cpu/dnnl_helper.hpp b/csrc/cpu/dnnl_helper.hpp
deleted file mode 100644
index 1cb8dc5b25a66..0000000000000
--- a/csrc/cpu/dnnl_helper.hpp
+++ /dev/null
@@ -1,206 +0,0 @@
-#ifndef DNNL_HELPER_HPP
-#define DNNL_HELPER_HPP
-
-#include
-#include
-
-#include "oneapi/dnnl/dnnl.hpp"
-
-namespace {
-template <typename T>
-struct DNNLType {
- static constexpr dnnl::memory::data_type type =
- dnnl::memory::data_type::undef;
-};
-
-template <>
-struct DNNLType<int8_t> {
- static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
-};
-
-template <>
-struct DNNLType<int32_t> {
- static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
-};
-
-template <>
-struct DNNLType<float> {
- static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
-};
-
-template <>
-struct DNNLType<c10::BFloat16> {
- static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
-};
-
-template <>
-struct DNNLType<c10::Half> {
- static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
-};
-
-template <typename T>
-constexpr inline dnnl::memory::data_type get_dnnl_type() {
- return DNNLType<std::decay_t<T>>::type;
-}
-}; // namespace
-
-template <bool InputNoScale>
-class DNNLPrimitiveHelper {
- public:
- // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias)
- // A: [M, K], row-major
- // B: [K, N], column-major
- // C: [M, N], row-major
- // bias: [N], row-major, optional
- // a_scales: [MS]
- // b_scales: [NS]
- // Note: Due to the limitation of oneDNN
- // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
- // not supported.
-
- template <typename OutputT, typename BiasT>
- static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
- const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
- dnnl_dim_t K, const float* a_scales,
- const float* b_scales, dnnl_dim_t MS,
- dnnl_dim_t NS) {
- auto&& OutputType = get_dnnl_type<OutputT>();
- auto&& BiasType = get_dnnl_type<BiasT>();
-
- dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});
- dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K});
- dnnl::memory::desc c_md({M, N}, OutputType, {N, 1});
-
- dnnl::primitive_attr attr;
- if constexpr (!InputNoScale) {
- if (MS == 1) {
- // per-tensor
- attr.set_scales_mask(DNNL_ARG_SRC, 0);
- } else {
- // per-token
- TORCH_CHECK(false, "per-token quantization is unsupported.");
- }
- }
-
- if (NS == 1) {
- // per-tensor
- attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
- } else {
- // per-channel
- attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
- }
-
- dnnl::matmul::primitive_desc matmul_pd;
-// Create memory descriptors with format_tag::any for the primitive. This
-// enables the matmul primitive to choose memory layouts for an
-// optimized primitive implementation, and these layouts may differ from the
-// ones provided by the user.
-#ifdef __aarch64__
- auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8,
- dnnl::memory::format_tag::any);
- auto mat_weights_md = dnnl::memory::desc(
- {K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any);
- auto mat_dst_md =
- dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any);
- if (bias) {
- dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
- matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md,
- mat_weights_md, bias_md,
- mat_dst_md, attr);
- } else {
- matmul_pd = dnnl::matmul::primitive_desc(
- default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr);
- }
-#else
- if (bias) {
- dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
- matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
- bias_md, c_md, attr);
- } else {
- matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
- c_md, attr);
- }
-#endif
- dnnl::matmul matmul(matmul_pd);
-
- auto& engine = default_engine();
-
- dnnl::memory a_m(a_md, engine, (void*)a);
- dnnl::memory b_m(b_md, engine, (void*)b);
- dnnl::memory c_m(c_md, engine, (void*)c);
- dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine,
- (void*)a_scales);
- dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine,
- (void*)b_scales);
-
- auto& stream = default_stream();
-
- auto mat_src_mem = a_m;
- auto mat_weights_mem = b_m;
- auto mat_dst_mem = c_m;
-#ifdef __aarch64__
- if (matmul_pd.weights_desc() != b_m.get_desc()) {
- mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine);
- dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem);
- }
-#endif
- if constexpr (InputNoScale) {
- if (bias) {
- dnnl::memory::desc bias_md({N}, BiasType, {1});
- dnnl::memory bias_m(bias_md, engine, (void*)bias);
- matmul.execute(
- stream, {
- {DNNL_ARG_SRC, mat_src_mem},
- {DNNL_ARG_WEIGHTS, mat_weights_mem},
- {DNNL_ARG_BIAS, bias_m},
- {DNNL_ARG_DST, mat_dst_mem},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
- });
- } else {
- matmul.execute(
- stream, {
- {DNNL_ARG_SRC, mat_src_mem},
- {DNNL_ARG_WEIGHTS, mat_weights_mem},
- {DNNL_ARG_DST, mat_dst_mem},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
- });
- }
- } else {
- if (bias) {
- dnnl::memory::desc bias_md({N}, BiasType, {1});
- dnnl::memory bias_m(bias_md, engine, (void*)bias);
- matmul.execute(
- stream, {
- {DNNL_ARG_SRC, mat_src_mem},
- {DNNL_ARG_WEIGHTS, mat_weights_mem},
- {DNNL_ARG_BIAS, bias_m},
- {DNNL_ARG_DST, mat_dst_mem},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
- });
- } else {
- matmul.execute(
- stream, {
- {DNNL_ARG_SRC, mat_src_mem},
- {DNNL_ARG_WEIGHTS, mat_weights_mem},
- {DNNL_ARG_DST, mat_dst_mem},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
- });
- }
- }
- stream.wait();
- }
-
- private:
- static dnnl::engine& default_engine() {
- static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
- return engine;
- }
-
- static dnnl::stream& default_stream() {
- static dnnl::stream stream(default_engine());
- return stream;
- }
-};
-#endif
diff --git a/csrc/cpu/dnnl_kernels.cpp b/csrc/cpu/dnnl_kernels.cpp
new file mode 100644
index 0000000000000..acc3b9ecde143
--- /dev/null
+++ b/csrc/cpu/dnnl_kernels.cpp
@@ -0,0 +1,494 @@
+#include "cpu_types.hpp"
+#include "dnnl_helper.h"
+
+namespace {
+template
+struct KernelVecType {
+ using load_vec_type = void;
+ using cvt_vec_type = void;
+};
+
+template <>
+struct KernelVecType {
+ using load_vec_type = vec_op::FP32Vec16;
+ using cvt_vec_type = vec_op::FP32Vec16;
+};
+
+#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
+template <>
+struct KernelVecType {
+ using load_vec_type = vec_op::BF16Vec16;
+ using cvt_vec_type = vec_op::FP32Vec16;
+};
+#endif
+
+template <>
+struct KernelVecType {
+#if defined(__powerpc64__) || defined(__s390x__)
+ // Power architecture-specific vector type
+ using load_vec_type = vec_op::FP32Vec16;
+#else
+ // Fallback for other architectures
+ using load_vec_type = vec_op::FP16Vec16;
+#endif
+ using cvt_vec_type = vec_op::FP32Vec16;
+};
+
+template
+void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
+ const float* scale, const int32_t* azp,
+ const int64_t num_tokens,
+ const int64_t input_stride,
+ const int64_t hidden_size) {
+ using load_vec_t = typename KernelVecType::load_vec_type;
+ using cvt_vec_t = typename KernelVecType::cvt_vec_type;
+ constexpr int64_t vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+
+ constexpr float i8_min =
+ static_cast(std::numeric_limits::min());
+ constexpr float i8_max =
+ static_cast(std::numeric_limits::max());
+ const cvt_vec_t inv_scale(1.0 / *scale);
+ const cvt_vec_t i8_min_vec(i8_min);
+ const cvt_vec_t i8_max_vec(i8_max);
+
+ cvt_vec_t zp_vec;
+ if constexpr (AZP) {
+ zp_vec = cvt_vec_t(static_cast(*azp));
+ }
+
+#pragma omp parallel for
+ for (int64_t i = 0; i < num_tokens; ++i) {
+ int64_t j = 0;
+ const scalar_t* input_ptr = input + i * input_stride;
+ int8_t* output_ptr = output + i * hidden_size;
+ for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+ elems_fp32 = elems_fp32 * inv_scale;
+
+ if constexpr (AZP) {
+ elems_fp32 = elems_fp32 + zp_vec;
+ }
+
+ elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+ vec_op::INT8Vec16 elems_int8(elems_fp32);
+ elems_int8.save(output_ptr + j);
+ }
+
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+ elems_fp32 = elems_fp32 * inv_scale;
+
+ if constexpr (AZP) {
+ elems_fp32 = elems_fp32 + zp_vec;
+ }
+
+ elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+ vec_op::INT8Vec16 elems_int8(elems_fp32);
+ elems_int8.save(output_ptr + j, hidden_size - j);
+ }
+}
+
+template
+void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
+ float* scale, int32_t* azp,
+ const int64_t num_tokens,
+ const int64_t input_stride,
+ const int64_t hidden_size) {
+ using load_vec_t = typename KernelVecType::load_vec_type;
+ using cvt_vec_t = typename KernelVecType::cvt_vec_type;
+ constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+
+ constexpr float i8_min =
+ static_cast(std::numeric_limits::min());
+ constexpr float i8_max =
+ static_cast(std::numeric_limits::max());
+ const cvt_vec_t i8_min_vec(i8_min);
+ const cvt_vec_t i8_max_vec(i8_max);
+
+#pragma omp parallel for
+ for (int64_t i = 0; i < num_tokens; ++i) {
+ cvt_vec_t max_value(std::numeric_limits::lowest());
+ cvt_vec_t min_value(std::numeric_limits::max());
+ {
+ int64_t j = 0;
+ const scalar_t* input_ptr = input + i * input_stride;
+ for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+ if constexpr (AZP) {
+ max_value = max_value.max(elems_fp32);
+ min_value = min_value.min(elems_fp32);
+ } else {
+ max_value = max_value.max(elems_fp32.abs());
+ }
+ }
+
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+
+ if (j + vec_elem_num == hidden_size) {
+ if constexpr (AZP) {
+ max_value = max_value.max(elems_fp32);
+ min_value = min_value.min(elems_fp32);
+ } else {
+ max_value = max_value.max(elems_fp32.abs());
+ }
+ } else {
+ if constexpr (AZP) {
+ max_value = max_value.max(elems_fp32, hidden_size - j);
+ min_value = min_value.min(elems_fp32, hidden_size - j);
+ } else {
+ max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
+ }
+ }
+ }
+
+ float scale_val, azp_val;
+ if constexpr (AZP) {
+ float max_scalar = max_value.reduce_max();
+ float min_scalar = min_value.reduce_min();
+ scale_val = (max_scalar - min_scalar) / 255.0f;
+ azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
+ azp[i] = azp_val;
+ scale[i] = scale_val;
+ } else {
+ scale_val = max_value.reduce_max() / 127.0f;
+ scale[i] = scale_val;
+ }
+
+ const cvt_vec_t inv_scale(1.0 / scale_val);
+ const cvt_vec_t azp_vec(azp_val);
+
+ {
+ int64_t j = 0;
+ const scalar_t* input_ptr = input + i * input_stride;
+ int8_t* output_ptr = output + i * hidden_size;
+ for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+ elems_fp32 = (elems_fp32 * inv_scale);
+
+ if constexpr (AZP) {
+ elems_fp32 = elems_fp32 + azp_vec;
+ }
+ elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+ vec_op::INT8Vec16 elems_int8(elems_fp32);
+ elems_int8.save(output_ptr + j);
+ }
+
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+ elems_fp32 = (elems_fp32 * inv_scale);
+
+ if constexpr (AZP) {
+ elems_fp32 = elems_fp32 + azp_vec;
+ }
+ elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+ vec_op::INT8Vec16 elems_int8(elems_fp32);
+ elems_int8.save(output_ptr + j, hidden_size - j);
+ }
+ }
+}
+
+template
+void dynamic_quant_epilogue(const float* input, scalar_t* output,
+ const float* a_scale, const int32_t* azp,
+ const float* azp_adj, const scalar_t* bias,
+ const int64_t num_tokens,
+ const int64_t hidden_size) {
+ CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
+ using load_vec_t = typename KernelVecType::load_vec_type;
+ using cvt_vec_t = typename KernelVecType::cvt_vec_type;
+ constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+
+ const int64_t thread_num = omp_get_max_threads();
+ if (num_tokens > thread_num) {
+#pragma omp parallel for
+ for (int64_t i = 0; i < num_tokens; ++i) {
+ const float* input_ptr = input + i * hidden_size;
+ scalar_t* output_ptr = output + i * hidden_size;
+ int64_t j = 0;
+ cvt_vec_t token_scale_vec(a_scale[i]);
+ cvt_vec_t token_zp_scale_vec;
+ if constexpr (AZP) {
+ float zp_scale_val = a_scale[i] * static_cast(azp[i]);
+ token_zp_scale_vec = cvt_vec_t(zp_scale_val);
+ }
+ for (; j < hidden_size - vec_elem_num; ++j) {
+ cvt_vec_t elems_fp32(input_ptr + j);
+ elems_fp32 = elems_fp32 * token_scale_vec;
+ if constexpr (AZP) {
+ cvt_vec_t azp_adj_fp32(azp_adj + j);
+ elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
+ }
+ if constexpr (Bias) {
+ load_vec_t bias_vec(bias + j);
+ cvt_vec_t bias_vec_fp32(bias_vec);
+ elems_fp32 = elems_fp32 + bias_vec_fp32;
+ }
+ load_vec_t elems_out(elems_fp32);
+ elems_out.save(output_ptr + j);
+ }
+ cvt_vec_t elems_fp32(input_ptr + j);
+ elems_fp32 = elems_fp32 * token_scale_vec;
+ if constexpr (AZP) {
+ cvt_vec_t azp_adj_fp32(azp_adj + j);
+ elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
+ }
+ if constexpr (Bias) {
+ load_vec_t bias_vec(bias + j);
+ cvt_vec_t bias_vec_fp32(bias_vec);
+ elems_fp32 = elems_fp32 + bias_vec_fp32;
+ }
+ load_vec_t elems_out(elems_fp32);
+ elems_out.save(output_ptr + j, hidden_size - j);
+ }
+ } else {
+ const int64_t vec_iteration =
+ (hidden_size + vec_elem_num - 1) / vec_elem_num;
+ const int64_t vec_iteration_per_thread =
+ (vec_iteration + thread_num - 1) / thread_num;
+ const int64_t elem_num_per_thread = vec_iteration_per_thread * vec_elem_num;
+#pragma omp parallel for schedule(static, 1)
+ for (int64_t i = 0; i < thread_num; ++i) {
+ const int64_t start = elem_num_per_thread * i;
+ const int64_t end = std::min(hidden_size, elem_num_per_thread + start);
+ for (int64_t j = 0; j < num_tokens; ++j) {
+ cvt_vec_t token_scale_vec(a_scale[j]);
+ cvt_vec_t token_zp_scale_vec;
+ if constexpr (AZP) {
+ float zp_scale_val = a_scale[j] * static_cast(azp[j]);
+ token_zp_scale_vec = cvt_vec_t(zp_scale_val);
+ }
+ int64_t k = start;
+ const float* input_ptr = input + j * hidden_size;
+ scalar_t* output_ptr = output + j * hidden_size;
+ for (; k < end - vec_elem_num; k += vec_elem_num) {
+ cvt_vec_t elems_fp32(input_ptr + k);
+ elems_fp32 = elems_fp32 * token_scale_vec;
+ if constexpr (AZP) {
+ cvt_vec_t azp_adj_fp32(azp_adj + k);
+ elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
+ }
+ if constexpr (Bias) {
+ load_vec_t bias_vec(bias + k);
+ cvt_vec_t bias_vec_fp32(bias_vec);
+ elems_fp32 = elems_fp32 + bias_vec_fp32;
+ }
+ load_vec_t elems_out(elems_fp32);
+ elems_out.save(output_ptr + k);
+ }
+ if (k < end) {
+ cvt_vec_t elems_fp32(input_ptr + k);
+ elems_fp32 = elems_fp32 * token_scale_vec;
+ if constexpr (AZP) {
+ cvt_vec_t azp_adj_fp32(azp_adj + k);
+ elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
+ }
+ if constexpr (Bias) {
+ load_vec_t bias_vec(bias + k);
+ cvt_vec_t bias_vec_fp32(bias_vec);
+ elems_fp32 = elems_fp32 + bias_vec_fp32;
+ }
+ load_vec_t elems_out(elems_fp32);
+ elems_out.save(output_ptr + k, end - k);
+ }
+ }
+ }
+ }
+}
+} // namespace
+
+int64_t create_onednn_scaled_mm_handler(
+ const torch::Tensor& b, // [IC, OC], column-major
+ const torch::Tensor& b_scales, // [1] or [OC]
+ at::ScalarType output_type, bool dynamic_act_quant, bool use_azp,
+ int64_t primitive_cache_size) {
+ TORCH_CHECK(b.dim() == 2);
+ TORCH_CHECK(b.stride(0) == 1); // Column-major
+ TORCH_CHECK(b_scales.is_contiguous());
+
+ W8A8MatMulPrimitiveHandler::Args args;
+ args.primitive_cache_size = primitive_cache_size;
+
+ if (b_scales.numel() == 1) {
+ args.b_quantization_strategy =
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR;
+ } else {
+ TORCH_CHECK_EQ(b_scales.numel(), b.size(1));
+ args.b_quantization_strategy =
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_OUTPUT_CHANNEL;
+ }
+ args.b_scales_ptr = b_scales.data_ptr();
+ args.b_k_size = b.size(0);
+ args.b_k_stride = b.stride(0);
+ args.b_n_size = b.size(1);
+ args.b_n_stride = b.stride(1);
+ args.b_ptr = b.data_ptr();
+
+ if (dynamic_act_quant) {
+ // dynamic per-token, bias, A scales and A zps will be applied in outside.
+ args.a_quantization_strategy =
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN;
+ args.use_a_zero_point = false;
+ } else {
+ // static per-tensor
+ args.a_quantization_strategy =
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR;
+ args.use_a_zero_point = use_azp;
+ }
+
+ VLLM_DISPATCH_FLOATING_TYPES(output_type, "create_onednn_scaled_mm_handler",
+ [&] {
+ if (dynamic_act_quant) {
+ args.c_type = get_dnnl_type();
+ } else {
+ args.c_type = get_dnnl_type();
+ }
+ });
+
+ return reinterpret_cast(new W8A8MatMulPrimitiveHandler(args));
+}
+
+void onednn_scaled_mm(
+ torch::Tensor& c, // [M, OC], row-major
+ const torch::Tensor& a, // [M, IC], row-major
+ const torch::Tensor& a_scales, // [M] or [1]
+ const std::optional& azp, // [M] or [1]
+ const std::optional& azp_adj, // [M] or [1]
+ const std::optional& bias, // [N]
+ int64_t handler) {
+ CPU_KERNEL_GUARD_IN(onednn_scaled_mm)
+ TORCH_CHECK(a.dim() == 2);
+ TORCH_CHECK(a.is_contiguous());
+ TORCH_CHECK(c.is_contiguous());
+ W8A8MatMulPrimitiveHandler* ptr =
+ reinterpret_cast(handler);
+ const int32_t* azp_ptr = nullptr;
+ if (azp.has_value()) {
+ azp_ptr = azp->data_ptr();
+ }
+ if (ptr->get_input_scale_strategy() ==
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) {
+ TORCH_CHECK_EQ(a_scales.numel(), 1);
+ }
+
+ W8A8MatMulPrimitiveHandler::ExecArgs exec_args;
+ exec_args.a_ptr = a.data_ptr();
+ exec_args.a_m_size = a.size(0);
+ exec_args.bias_ptr = nullptr;
+ exec_args.use_bias = false;
+ exec_args.a_scales_ptr = nullptr;
+ exec_args.a_zero_points_ptr = nullptr;
+
+ VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "onednn_scaled_mm", [&] {
+ if (ptr->get_input_scale_strategy() ==
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) {
+ if (bias.has_value()) {
+ exec_args.bias_ptr = bias->data_ptr();
+ exec_args.bias_type = get_dnnl_type();
+ exec_args.use_bias = true;
+ }
+ exec_args.a_scales_ptr = a_scales.data_ptr();
+ exec_args.a_zero_points_ptr = azp_ptr;
+ exec_args.c_ptr = c.data_ptr();
+ ptr->execute(exec_args);
+ } else if (ptr->get_input_scale_strategy() ==
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN) {
+ torch::Tensor tmp_fp32_out =
+ torch::empty_like(c, ::at::ScalarType::Float);
+ exec_args.c_ptr = tmp_fp32_out.data_ptr();
+ ptr->execute(exec_args);
+ if (bias.has_value()) {
+ if (azp.has_value()) {
+ dynamic_quant_epilogue(
+ tmp_fp32_out.data_ptr(), c.data_ptr(),
+ a_scales.data_ptr(), azp_ptr, azp_adj->data_ptr(),
+ bias->data_ptr(), c.size(0), c.size(1));
+ } else {
+ dynamic_quant_epilogue(
+ tmp_fp32_out.data_ptr(), c.data_ptr(),
+ a_scales.data_ptr(), azp_ptr, nullptr,
+ bias->data_ptr(), c.size(0), c.size(1));
+ }
+ } else {
+ if (azp.has_value()) {
+ dynamic_quant_epilogue(
+ tmp_fp32_out.data_ptr(), c.data_ptr(),
+ a_scales.data_ptr(), azp_ptr, azp_adj->data_ptr(),
+ (scalar_t*)nullptr, c.size(0), c.size(1));
+ } else {
+ dynamic_quant_epilogue(
+ tmp_fp32_out.data_ptr(), c.data_ptr(),
+ a_scales.data_ptr(), azp_ptr, nullptr, (scalar_t*)nullptr,
+ c.size(0), c.size(1));
+ }
+ }
+ } else {
+ TORCH_CHECK(false, "invalid act quant type.");
+ }
+ });
+}
+
+// static-per-tensor quantization.
+void static_scaled_int8_quant(
+ torch::Tensor& out, // [batch, hidden_size]
+ const torch::Tensor& input, // [batch, hidden_size]
+ const torch::Tensor& scale, std::optional const& azp) {
+ CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
+ TORCH_CHECK(out.is_contiguous());
+ TORCH_CHECK_EQ(input.dim(), 2);
+ TORCH_CHECK_EQ(input.stride(1), 1);
+ TORCH_CHECK(scale.numel() == 1);
+ TORCH_CHECK(!azp.has_value() || azp->numel() == 1);
+
+ const int64_t stride = input.stride(0);
+ const int64_t hidden_size = input.size(1);
+ const int64_t num_tokens = input.size(0);
+ VLLM_DISPATCH_FLOATING_TYPES(
+ input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
+ if (azp.has_value()) {
+ static_scaled_int8_quant_impl(
+ input.data_ptr(), out.data_ptr(),
+ scale.data_ptr(), azp->data_ptr(), num_tokens,
+ stride, hidden_size);
+ } else {
+ static_scaled_int8_quant_impl(input.data_ptr(),
+ out.data_ptr(),
+ scale.data_ptr(), nullptr,
+ num_tokens, stride, hidden_size);
+ }
+ });
+}
+
+// dynamic-per-token quantization.
+void dynamic_scaled_int8_quant(
+ torch::Tensor& out, // [batch, hidden_size]
+ const torch::Tensor& input, // [batch, hidden_size]
+ torch::Tensor& scale, // [batch, 1]
+ std::optional const& azp) {
+ CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
+ TORCH_CHECK(out.is_contiguous());
+ TORCH_CHECK_EQ(input.dim(), 2);
+ TORCH_CHECK_EQ(input.stride(1), 1);
+
+ const int64_t hidden_size = input.size(1);
+ const int64_t num_tokens = input.size(0);
+ const int64_t stride = input.stride(0);
+ VLLM_DISPATCH_FLOATING_TYPES(
+ input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
+ if (azp.has_value()) {
+ dynamic_scaled_int8_quant_impl(
+ input.data_ptr(), out.data_ptr(),
+ scale.data_ptr(), azp->data_ptr(), num_tokens,
+ stride, hidden_size);
+ } else {
+ dynamic_scaled_int8_quant_impl(
+ input.data_ptr(), out.data_ptr(),
+ scale.data_ptr(), nullptr, num_tokens, stride,
+ hidden_size);
+ }
+ });
+}
diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp
deleted file mode 100644
index 6e120b8d20a7e..0000000000000
--- a/csrc/cpu/quant.cpp
+++ /dev/null
@@ -1,951 +0,0 @@
-#include "cpu_types.hpp"
-#include "dnnl_helper.hpp"
-
-namespace {
-template
-struct KernelVecType {
- using load_vec_type = void;
- using azp_adj_load_vec_type = void;
- using cvt_vec_type = void;
-};
-
-template <>
-struct KernelVecType {
- using load_vec_type = vec_op::FP32Vec16;
- using azp_adj_load_vec_type = vec_op::INT32Vec16;
- using cvt_vec_type = vec_op::FP32Vec16;
-};
-
-#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
-template <>
-struct KernelVecType {
- using load_vec_type = vec_op::BF16Vec16;
- using azp_adj_load_vec_type = vec_op::INT32Vec16;
- using cvt_vec_type = vec_op::FP32Vec16;
-};
-#endif
-
-template <>
-struct KernelVecType {
-#if defined(__powerpc64__) || defined(__s390x__)
- // Power architecture-specific vector type
- using load_vec_type = vec_op::FP32Vec16;
-#else
- // Fallback for other architectures
- using load_vec_type = vec_op::FP16Vec16;
-#endif
- using azp_adj_load_vec_type = vec_op::INT32Vec16;
- using cvt_vec_type = vec_op::FP32Vec16;
-};
-
-#if defined(__AVX512F__) || defined(__aarch64__)
-template
-void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- const float* scale, const int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- using load_vec_t = typename KernelVecType::load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- constexpr float i8_min =
- static_cast(std::numeric_limits::min());
- constexpr float i8_max =
- static_cast(std::numeric_limits::max());
- const cvt_vec_t inv_scale(1.0 / *scale);
- const cvt_vec_t i8_min_vec(i8_min);
- const cvt_vec_t i8_max_vec(i8_max);
-
- cvt_vec_t zp_vec;
- if constexpr (AZP) {
- zp_vec = cvt_vec_t(static_cast(*azp));
- }
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = elems_fp32 * inv_scale;
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + zp_vec;
- }
-
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j);
- }
-
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = elems_fp32 * inv_scale;
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + zp_vec;
- }
-
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-
-template
-void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- float* scale, int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- using load_vec_t = typename KernelVecType::load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- constexpr float i8_min =
- static_cast(std::numeric_limits::min());
- constexpr float i8_max =
- static_cast(std::numeric_limits::max());
- const cvt_vec_t i8_min_vec(i8_min);
- const cvt_vec_t i8_max_vec(i8_max);
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- cvt_vec_t max_value(std::numeric_limits::lowest());
- cvt_vec_t min_value(std::numeric_limits::max());
- {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32);
- min_value = min_value.min(elems_fp32);
- } else {
- max_value = max_value.max(elems_fp32.abs());
- }
- }
-
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
-
- if (j + vec_elem_num == hidden_size) {
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32);
- min_value = min_value.min(elems_fp32);
- } else {
- max_value = max_value.max(elems_fp32.abs());
- }
- } else {
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32, hidden_size - j);
- min_value = min_value.min(elems_fp32, hidden_size - j);
- } else {
- max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
- }
- }
- }
-
- float scale_val, azp_val;
- if constexpr (AZP) {
- float max_scalar = max_value.reduce_max();
- float min_scalar = min_value.reduce_min();
- scale_val = (max_scalar - min_scalar) / 255.0f;
- azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
- azp[i] = static_cast(azp_val);
- scale[i] = scale_val;
- } else {
- scale_val = max_value.reduce_max() / 127.0f;
- scale[i] = scale_val;
- }
-
- const cvt_vec_t inv_scale(1.0 / scale_val);
- const cvt_vec_t azp_vec(azp_val);
-
- {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = (elems_fp32 * inv_scale);
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + azp_vec;
- }
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j);
- }
-
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = (elems_fp32 * inv_scale);
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + azp_vec;
- }
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j, hidden_size - j);
- }
- }
-}
-
-template
-void static_quant_epilogue(const float* input, scalar_t* output,
- const float a_scale, const float* b_scale,
- const int32_t* azp_with_adj, const int num_tokens,
- const int hidden_size) {
- CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
- using load_vec_t = typename KernelVecType::load_vec_type;
- using azp_adj_load_vec_t =
- typename KernelVecType::azp_adj_load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- cvt_vec_t a_scale_vec(a_scale);
- cvt_vec_t b_scale_vec(*b_scale);
- cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;
-
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
-
- if constexpr (PerChannel) {
- b_scale_vec = cvt_vec_t(b_scale + j);
- scale_vec = b_scale_vec * a_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j);
- }
-
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
-
- if constexpr (PerChannel) {
- b_scale_vec = cvt_vec_t(b_scale + j);
- scale_vec = b_scale_vec * a_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-
-template
-void dynamic_quant_epilogue(const float* input, scalar_t* output,
- const float* a_scale, const float* b_scale,
- const int32_t* azp, const int32_t* azp_adj,
- const scalar_t* bias, const int num_tokens,
- const int hidden_size) {
- CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
- using load_vec_t = typename KernelVecType::load_vec_type;
- using azp_adj_load_vec_t =
- typename KernelVecType::azp_adj_load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- int j = 0;
- cvt_vec_t token_scale_vec(a_scale[i]);
- cvt_vec_t token_zp_scale_vec;
- if constexpr (AZP) {
- float zp_scale_val = a_scale[i] * static_cast(azp[i]);
- if constexpr (!PerChannel) {
- zp_scale_val *= *b_scale;
- }
- token_zp_scale_vec = cvt_vec_t(zp_scale_val);
- }
-
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- elems_fp32 = elems_fp32 * token_scale_vec;
-
- if constexpr (AZP) {
- azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
- azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
-
- if constexpr (PerChannel) {
- cvt_vec_t b_scale_vec(b_scale + j);
- azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - azp_adj_fp32;
- }
-
- if constexpr (Bias) {
- load_vec_t bias_vec(bias + j);
- cvt_vec_t bias_vec_fp32(bias_vec);
- elems_fp32 = elems_fp32 + bias_vec_fp32;
- }
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j);
- }
-
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- elems_fp32 = elems_fp32 * token_scale_vec;
-
- if constexpr (AZP) {
- azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
- azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
-
- if constexpr (PerChannel) {
- cvt_vec_t b_scale_vec(b_scale + j);
- azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - azp_adj_fp32;
- }
-
- if constexpr (Bias) {
- load_vec_t bias_vec(bias + j);
- cvt_vec_t bias_vec_fp32(bias_vec);
- elems_fp32 = elems_fp32 + bias_vec_fp32;
- }
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-#elif defined(__powerpc64__)
-template
-void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- const float* scale, const int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- using load_vec_t = typename KernelVecType::load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- constexpr float i8_min =
- static_cast(std::numeric_limits::min());
- constexpr float i8_max =
- static_cast(std::numeric_limits::max());
-
- const cvt_vec_t inv_scale(1.0 / *scale);
- const cvt_vec_t i8_min_vec(i8_min);
- const cvt_vec_t i8_max_vec(i8_max);
-
- cvt_vec_t zp_vec;
- if constexpr (AZP) {
- zp_vec = cvt_vec_t(static_cast(*azp));
- }
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = elems_fp32 * inv_scale;
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + zp_vec;
- }
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j);
- }
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = elems_fp32 * inv_scale;
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + zp_vec;
- }
-
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-template
-void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- float* scale, int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- using load_vec_t = typename KernelVecType::load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- constexpr float i8_min =
- static_cast(std::numeric_limits::min());
- constexpr float i8_max =
- static_cast(std::numeric_limits::max());
- const cvt_vec_t i8_min_vec(i8_min);
- const cvt_vec_t i8_max_vec(i8_max);
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- cvt_vec_t max_value(std::numeric_limits::lowest());
- cvt_vec_t min_value(std::numeric_limits::max());
- {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32);
- min_value = min_value.min(elems_fp32);
- } else {
- max_value = max_value.max(elems_fp32.abs());
- }
- }
-
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
-
- if (j + vec_elem_num == hidden_size) {
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32);
- min_value = min_value.min(elems_fp32);
- } else {
- max_value = max_value.max(elems_fp32.abs());
- }
- } else {
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32, hidden_size - j);
- min_value = min_value.min(elems_fp32, hidden_size - j);
- } else {
- max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
- }
- }
- }
-
- float scale_val, azp_val;
- if constexpr (AZP) {
- float max_scalar = max_value.reduce_max();
- float min_scalar = min_value.reduce_min();
- scale_val = (max_scalar - min_scalar) / 255.0f;
- azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
- azp[i] = static_cast(azp_val);
- scale[i] = scale_val;
- } else {
- scale_val = max_value.reduce_max() / 127.0f;
- scale[i] = scale_val;
- }
-
- const cvt_vec_t inv_scale(1.0 / scale_val);
- const cvt_vec_t azp_vec(azp_val);
-
- {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = (elems_fp32 * inv_scale);
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + azp_vec;
- }
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j);
- }
-
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = (elems_fp32 * inv_scale);
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + azp_vec;
- }
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j, hidden_size - j);
- }
- }
-}
-template
-void static_quant_epilogue(const float* input, scalar_t* output,
- const float a_scale, const float* b_scale,
- const int32_t* azp_with_adj, const int num_tokens,
- const int hidden_size) {
- CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
- using load_vec_t = typename KernelVecType::load_vec_type;
- using azp_adj_load_vec_t =
- typename KernelVecType::azp_adj_load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- cvt_vec_t a_scale_vec(a_scale);
- cvt_vec_t b_scale_vec(*b_scale);
- cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;
-
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
-
- if constexpr (PerChannel) {
- b_scale_vec = cvt_vec_t(b_scale + j);
- scale_vec = b_scale_vec * a_scale_vec;
- }
- elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j);
- }
-
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
-
- if constexpr (PerChannel) {
- b_scale_vec = cvt_vec_t(b_scale + j);
- scale_vec = b_scale_vec * a_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-template
-void dynamic_quant_epilogue(const float* input, scalar_t* output,
- const float* a_scale, const float* b_scale,
- const int32_t* azp, const int32_t* azp_adj,
- const scalar_t* bias, const int num_tokens,
- const int hidden_size) {
- CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
- using load_vec_t = typename KernelVecType::load_vec_type;
- using azp_adj_load_vec_t =
- typename KernelVecType::azp_adj_load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- int j = 0;
- cvt_vec_t token_scale_vec(a_scale[i]);
- cvt_vec_t token_zp_scale_vec;
- if constexpr (AZP) {
- float zp_scale_val = a_scale[i] * static_cast(azp[i]);
- if constexpr (!PerChannel) {
- zp_scale_val *= *b_scale;
- }
- token_zp_scale_vec = cvt_vec_t(zp_scale_val);
- }
-
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- elems_fp32 = elems_fp32 * token_scale_vec;
-
- if constexpr (AZP) {
- azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
- azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
-
- if constexpr (PerChannel) {
- cvt_vec_t b_scale_vec(b_scale + j);
- azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - azp_adj_fp32;
- }
-
- if constexpr (Bias) {
- load_vec_t bias_vec(bias + j);
- cvt_vec_t bias_vec_fp32(bias_vec);
- elems_fp32 = elems_fp32 + bias_vec_fp32;
- }
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j);
- }
-
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- elems_fp32 = elems_fp32 * token_scale_vec;
-
- if constexpr (AZP) {
- azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
- azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
-
- if constexpr (PerChannel) {
- cvt_vec_t b_scale_vec(b_scale + j);
- azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - azp_adj_fp32;
- }
-
- if constexpr (Bias) {
- load_vec_t bias_vec(bias + j);
- cvt_vec_t bias_vec_fp32(bias_vec);
- elems_fp32 = elems_fp32 + bias_vec_fp32;
- }
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-#else
-template
-void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- const float* scale, const int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- TORCH_CHECK(false,
- "static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 "
- "support.")
-}
-
-template
-void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- float* scale, int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- TORCH_CHECK(false,
- "dynamic_scaled_int8_quant_impl requires "
- "AVX512/powerpc64/AArch64 support.")
-}
-
-template
-void static_quant_epilogue(const float* input, scalar_t* output,
- const float a_scale, const float* b_scale,
- const int32_t* azp_with_adj, const int num_tokens,
- const int hidden_size) {
- TORCH_CHECK(
- false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
-}
-
-template
-void dynamic_quant_epilogue(const float* input, scalar_t* output,
- const float* a_scale, const float* b_scale,
- const int32_t* azp, const int32_t* azp_with_adj,
- const scalar_t* bias, const int num_tokens,
- const int hidden_size) {
- TORCH_CHECK(
- false,
- "dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
-}
-#endif
-} // namespace
-
-void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
- const torch::Tensor& a, // [M, IC], row-major
- const torch::Tensor& b, // [IC, OC], column-major
- const torch::Tensor& a_scales, // [1] or [M]
- const torch::Tensor& b_scales, // [1] or [OC]
- const std::optional& bias // [OC]
-) {
- CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
- // Checks for conformality
- TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
- "int8_scaled_mm only supports INT8 inputs.")
- TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
- TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
- b.size(1) == c.size(1));
- TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
- TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
-
- // Check for strides and alignment
- TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
- TORCH_CHECK(b.stride(0) == 1); // Column-major
- TORCH_CHECK(c.stride(0) % 16 == 0 &&
- b.stride(1) % 16 == 0); // 16 Byte Alignment
- TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-
- if (bias) {
- TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
- bias->dim() == 1);
- }
-
- VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] {
- if (a_scales.numel() != 1) {
- // per-token
- // Note: oneDNN doesn't support per-token activation quantization
- // Ideally we want to fuse the GEMM and the scale procedure with oneDNN
- // JIT, the intermediate data is cached in registers or L1. But for now
- // the oneDNN GEMM code generation only supports two quantization
- // patterns: per-tensor or per-output-channel of weight.
- // So we have to apply the per-token scale with a 'epilogue'. In C=s_a *
- // s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN
- // GEMM, then the per-token scale (and bias) is applied with the epilogue
- // C=s_a * C_inter + bias.
- torch::Tensor tmp_fp32_out =
- torch::empty_like(c, ::at::ScalarType::Float);
- // Compute C_inter=s_b * (A@B)
- DNNLPrimitiveHelper::gemm_s8s8_jit(
- a.data_ptr(), b.data_ptr(),
- tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1),
- a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel());
- if (bias.has_value()) {
- // Compute C=s_a * C_inter + bias
- dynamic_quant_epilogue(
- tmp_fp32_out.data_ptr(), c.data_ptr(),
- a_scales.data_ptr(), nullptr, nullptr, nullptr,
- bias->data_ptr(), c.size(0), c.size(1));
- } else {
- // Compute C=s_a * C_inter
- dynamic_quant_epilogue(
- tmp_fp32_out.data_ptr(), c.data_ptr(),
- a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr,
- c.size(0), c.size(1));
- }
- } else {
- // per-tensor
- if (bias.has_value()) {
- // Compute C=s_a * s_b * (A@B) + bias
- DNNLPrimitiveHelper::gemm_s8s8_jit(
- a.data_ptr