diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py index 7045d8810493e..bbed80ebe8476 100644 --- a/.buildkite/generate_index.py +++ b/.buildkite/generate_index.py @@ -8,7 +8,8 @@ template = """

Links for vLLM

-    <a href="../{wheel_html_escaped}">{wheel}</a><br/>
+    <a href="../{x86_wheel_html_escaped}">{x86_wheel}</a><br/>
+    <a href="../{arm_wheel_html_escaped}">{arm_wheel}</a><br/>
""" @@ -21,7 +22,25 @@ filename = os.path.basename(args.wheel) with open("index.html", "w") as f: print(f"Generated index.html for {args.wheel}") + # sync the abi tag with .buildkite/scripts/upload-wheels.sh + if "x86_64" in filename: + x86_wheel = filename + arm_wheel = filename.replace("x86_64", "aarch64").replace( + "manylinux1", "manylinux2014" + ) + elif "aarch64" in filename: + x86_wheel = filename.replace("aarch64", "x86_64").replace( + "manylinux2014", "manylinux1" + ) + arm_wheel = filename + else: + raise ValueError(f"Unsupported wheel: {filename}") # cloudfront requires escaping the '+' character f.write( - template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B")) + template.format( + x86_wheel=x86_wheel, + x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"), + arm_wheel=arm_wheel, + arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"), + ) ) diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml deleted file mode 100644 index 56ec933c9cc0e..0000000000000 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml +++ /dev/null @@ -1,12 +0,0 @@ -# For vllm script, with -t option (tensor parallel size). -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1 -model_name: "HandH1998/QQQ-Llama-3-8b-g128" -tasks: -- name: "gsm8k" - metrics: - - name: "exact_match,strict-match" - value: 0.419 - - name: "exact_match,flexible-extract" - value: 0.416 -limit: 1000 -num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt index 27a1a9a82bd35..37eeac85c933b 100644 --- a/.buildkite/lm-eval-harness/configs/models-large.txt +++ b/.buildkite/lm-eval-harness/configs/models-large.txt @@ -3,4 +3,3 @@ Meta-Llama-3-70B-Instruct.yaml Mixtral-8x7B-Instruct-v0.1.yaml Qwen2-57B-A14-Instruct.yaml DeepSeek-V2-Lite-Chat.yaml -Meta-Llama-3-8B-QQQ.yaml diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh index a67fc89d54e60..897f84d1e360d 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh @@ -2,7 +2,7 @@ # We can use this script to compute baseline accuracy on GSM for transformers. # # Make sure you have lm-eval-harness installed: -# pip install lm-eval==0.4.4 +# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] usage() { echo`` diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh index b98d42aa7b822..792f355c47a51 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -3,7 +3,7 @@ # We use this for fp8, which HF does not support. 
# # Make sure you have lm-eval-harness installed: -# pip install lm-eval==0.4.4 +# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] usage() { echo`` diff --git a/.buildkite/nightly-benchmarks/scripts/compare-json-results.py b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py index 12c4ba6aa69a6..50431d0cd4c5e 100644 --- a/.buildkite/nightly-benchmarks/scripts/compare-json-results.py +++ b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py @@ -3,44 +3,129 @@ import argparse import json import os +from importlib import util import pandas as pd +plotly_found = util.find_spec("plotly.express") is not None + def compare_data_columns( files, name_column, data_column, info_cols, drop_column, debug=False ): - print("\ncompare_data_column: " + data_column) + """ + Align concatenation by keys derived from info_cols instead of row order. + - Pick one canonical key list: subset of info_cols present in ALL files. + - For each file: set index to those keys, aggregate duplicates + - (mean for metric, first for names). + - Concat along axis=1 (indexes align), then reset_index so callers can + - group by columns. + - If --debug, add a _name column per file. + """ + print("\ncompare_data_column:", data_column) + frames = [] raw_data_cols = [] compare_frames = [] + + # 1) choose a canonical key list from info_cols that exists in ALL files + cols_per_file = [] + for f in files: + try: + df_tmp = pd.read_json(f, orient="records") + except Exception as err: + raise ValueError(f"Failed to read {f}") from err + cols_per_file.append(set(df_tmp.columns)) + + key_cols = [c for c in info_cols if all(c in cset for cset in cols_per_file)] + if not key_cols: + # soft fallback: use any info_cols present in the first file + key_cols = [c for c in info_cols if c in list(cols_per_file[0])] + if not key_cols: + raise ValueError( + "No common key columns found from info_cols across the input files." 
+ ) + + # 2) build a single "meta" block (keys as columns) once, aligned by the key index + meta_added = False + for file in files: - data_df = pd.read_json(file) - serving_df = data_df.dropna(subset=[drop_column], ignore_index=True) - # Show all info columns in the first couple columns - if not frames: - for col in info_cols: - if col not in serving_df.columns: - print(f"Skipping missing column: {col}") - continue - frames.append(serving_df[col]) - # only show test name under debug mode - if debug is True: - serving_df = serving_df.rename(columns={name_column: file + "_name"}) - frames.append(serving_df[file + "_name"]) + df = pd.read_json(file, orient="records") - file = "/".join(file.split("/")[:-1]) - serving_df = serving_df.rename(columns={data_column: file}) - frames.append(serving_df[file]) - raw_data_cols.append(file) - compare_frames.append(serving_df[file]) + # Keep rows that actually have the compared metric (same as original behavior) + if drop_column in df.columns: + df = df.dropna(subset=[drop_column], ignore_index=True) + + # Stabilize numeric key columns (harmless if missing) + for c in ( + "Input Len", + "Output Len", + "TP Size", + "PP Size", + "# of max concurrency.", + "qps", + ): + if c in df.columns: + df[c] = pd.to_numeric(df[c], errors="coerce") + + # Ensure all key columns exist + for c in key_cols: + if c not in df.columns: + df[c] = pd.NA + + # Set index = key_cols and aggregate duplicates → unique MultiIndex + df_idx = df.set_index(key_cols, drop=False) + + # meta (key columns), unique per key + meta = df_idx[key_cols] + if not meta.index.is_unique: + meta = meta.groupby(level=key_cols, dropna=False).first() + + # metric series for this file, aggregated to one row per key + file_label = "/".join(file.split("/")[:-1]) or os.path.basename(file) + s = df_idx[data_column] + if not s.index.is_unique: + s = s.groupby(level=key_cols, dropna=False).mean() + s.name = file_label # column label like original + + # add meta once (from first file) so keys are the leftmost columns + if not meta_added: + frames.append(meta) + meta_added = True + + # (NEW) debug: aligned test-name column per file + if debug and name_column in df_idx.columns: + name_s = df_idx[name_column] + if not name_s.index.is_unique: + name_s = name_s.groupby(level=key_cols, dropna=False).first() + name_s.name = f"{file_label}_name" + frames.append(name_s) + + frames.append(s) + raw_data_cols.append(file_label) + compare_frames.append(s) + + # Generalize ratio: for any file N>=2, add ratio (fileN / file1) if len(compare_frames) >= 2: - # Compare numbers among two files - ratio_df = compare_frames[1] / compare_frames[0] - frames.append(ratio_df) - compare_frames.pop(1) + base = compare_frames[0] + current = compare_frames[-1] + ratio = current / base + ratio = ratio.mask(base == 0) # avoid inf when baseline is 0 + ratio.name = f"Ratio 1 vs {len(compare_frames)}" + frames.append(ratio) + # 4) concat on columns with aligned MultiIndex; + # then reset_index to return keys as columns concat_df = pd.concat(frames, axis=1) + concat_df = concat_df.reset_index(drop=True).reset_index() + if "index" in concat_df.columns: + concat_df = concat_df.drop(columns=["index"]) + + # Ensure key/info columns appear first (in your info_cols order) + front = [c for c in info_cols if c in concat_df.columns] + rest = [c for c in concat_df.columns if c not in front] + concat_df = concat_df[front + rest] + print(raw_data_cols) return concat_df, raw_data_cols @@ -67,6 +152,15 @@ def split_json_by_tp_pp( df = pd.DataFrame(data) + # Keep 
only "serving" tests + name_col = next( + (c for c in ["Test name", "test_name", "Test Name"] if c in df.columns), None + ) + if name_col: + df = df[ + df[name_col].astype(str).str.contains(r"serving", case=False, na=False) + ].copy() + # Handle alias column names rename_map = { "tp_size": "TP Size", @@ -181,7 +275,6 @@ if __name__ == "__main__": f"Expected subset: {filtered_info_cols}, " f"but DataFrame has: {list(output_df.columns)}" ) - output_df_sorted = output_df.sort_values(by=existing_group_cols) output_groups = output_df_sorted.groupby(existing_group_cols, dropna=False) for name, group in output_groups: @@ -189,8 +282,7 @@ if __name__ == "__main__": text_file.write(html_msgs_for_data_cols[i]) text_file.write(html) - if plot is True: - import pandas as pd + if plot and plotly_found: import plotly.express as px df = group[raw_data_cols] diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 85d3e56387421..f96c38bf57db7 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -27,7 +27,12 @@ steps: env: DOCKER_BUILDKIT: "1" + - block: "Build CUDA 12.6 wheel" + key: block-build-cu126-wheel + depends_on: ~ + - label: "Build wheel - CUDA 12.6" + depends_on: block-build-cu126-wheel id: build-wheel-cuda-12-6 agents: queue: cpu_queue_postmerge @@ -68,7 +73,7 @@ steps: queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" - label: "Annotate release workflow" diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 57a7bc4e5f5df..9dec9f8e9eb32 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -46,6 +46,11 @@ function cpu_tests() { set -e python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" + # Run kernel tests + docker exec cpu-test-"$NUMA_NODE" bash -c " + set -e + pytest -v -s tests/kernels/test_onednn.py" + # Run basic model test docker exec cpu-test-"$NUMA_NODE" bash -c " set -e @@ -99,4 +104,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 40 mins. 
export -f cpu_tests -timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index deb61a9bafab6..445cd2735c190 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -23,9 +23,13 @@ docker run \ --device /dev/dri \ -v /dev/dri/by-path:/dev/dri/by-path \ --entrypoint="" \ + -e "HF_TOKEN=${HF_TOKEN}" \ + -e "ZE_AFFINITY_MASK=${ZE_AFFINITY_MASK}" \ --name "${container_name}" \ "${image_name}" \ - sh -c ' + bash -c ' + set -e + echo $ZE_AFFINITY_MASK VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp @@ -35,8 +39,8 @@ docker run \ pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py pytest -v -s v1/structured_output - pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py - pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py + pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py + pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py pytest -v -s v1/test_serial_utils.py pytest -v -s v1/test_utils.py pytest -v -s v1/test_metrics_reader.py diff --git a/.buildkite/scripts/tpu/cleanup_docker.sh b/.buildkite/scripts/tpu/cleanup_docker.sh index 209d9c4341cdd..740d81fb39bb0 100755 --- a/.buildkite/scripts/tpu/cleanup_docker.sh +++ b/.buildkite/scripts/tpu/cleanup_docker.sh @@ -17,7 +17,7 @@ if [ "$disk_usage" -gt "$threshold" ]; then # Remove dangling images (those that are not tagged and not used by any container) docker image prune -f # Remove unused volumes / force the system prune for old images as well. - docker volume prune -f && docker system prune --force --filter "until=72h" --all + docker volume prune -f && docker system prune --force --filter "until=24h" --all echo "Docker images and volumes cleanup completed." else echo "Disk usage is below $threshold%. No cleanup needed." 
diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh index 037897e53dbef..745f285c008ad 100644 --- a/.buildkite/scripts/upload-wheels.sh +++ b/.buildkite/scripts/upload-wheels.sh @@ -14,8 +14,19 @@ fi # Get the single wheel file wheel="${wheel_files[0]}" -# Rename 'linux' to 'manylinux1' in the wheel filename -new_wheel="${wheel/linux/manylinux1}" +# Detect architecture and rename 'linux' to appropriate manylinux version +arch=$(uname -m) +if [[ $arch == "x86_64" ]]; then + manylinux_version="manylinux1" +elif [[ $arch == "aarch64" ]]; then + manylinux_version="manylinux2014" +else + echo "Warning: Unknown architecture $arch, using manylinux1 as default" + manylinux_version="manylinux1" +fi + +# Rename 'linux' to the appropriate manylinux version in the wheel filename +new_wheel="${wheel/linux/$manylinux_version}" mv -- "$wheel" "$new_wheel" wheel="$new_wheel" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 4fc8857854927..df2735fefeedb 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -88,15 +88,6 @@ steps: - pytest -v -s basic_correctness/test_cpu_offload.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py -- label: Chunked Prefill Test - mirror_hardwares: [amdexperimental] - source_file_dependencies: - - vllm/ - - tests/basic_correctness/test_chunked_prefill - commands: - - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - - label: Core Test # 10min mirror_hardwares: [amdexperimental] fast_check: true @@ -135,7 +126,8 @@ steps: - tests/entrypoints/test_chat_utils commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ + - PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py - pytest -v -s entrypoints/test_chat_utils.py - label: Distributed Tests (4 GPUs) # 10min @@ -252,6 +244,7 @@ steps: - pytest -v -s v1/core - pytest -v -s v1/engine - pytest -v -s v1/entrypoints + - pytest -v -s v1/executor - pytest -v -s v1/sample - pytest -v -s v1/logits_processors - pytest -v -s v1/worker @@ -295,15 +288,6 @@ steps: - python3 offline_inference/basic/score.py - VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 -- label: Prefix Caching Test # 9min - mirror_hardwares: [amdexperimental] - source_file_dependencies: - - vllm/ - - tests/prefix_caching - commands: - - pytest -v -s prefix_caching - - - label: Platform Tests (CUDA) mirror_hardwares: [amdexperimental] source_file_dependencies: @@ -345,6 +329,7 @@ steps: - pytest -v -s compile/test_sequence_parallelism.py - pytest -v -s compile/test_async_tp.py - pytest -v -s compile/test_fusion_all_reduce.py + - pytest -v -s compile/test_decorator.py - label: PyTorch Fullgraph Smoke Test # 9min mirror_hardwares: 
[amdexperimental] @@ -358,6 +343,7 @@ steps: - pytest -v -s compile/piecewise/test_simple.py - pytest -v -s compile/piecewise/test_toy_llama.py - pytest -v -s compile/piecewise/test_full_cudagraph.py + - pytest -v -s compile/piecewise/test_multiple_graphs.py - label: PyTorch Fullgraph Test # 18min mirror_hardwares: [amdexperimental] @@ -468,13 +454,11 @@ steps: - label: LM Eval Small Models # 53min mirror_hardwares: [amdexperimental] - working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" source_file_dependencies: - csrc/ - vllm/model_executor/layers/quantization commands: - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 - label: OpenAI API correctness mirror_hardwares: [amdexperimental] @@ -562,6 +546,15 @@ steps: commands: - pytest -v -s models/language/pooling -m 'not core_model' +- label: Multi-Modal Processor Test + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py + - pytest -v -s models/multimodal/processing/test_tensor_schema.py + - label: Multi-Modal Models Test (Standard) mirror_hardwares: [amdexperimental] torch_nightly: true @@ -571,9 +564,7 @@ steps: commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pip freeze | grep -E 'torch' - - pytest -v -s models/multimodal/processing - - pytest -v -s --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/test_tensor_schema.py models/multimodal -m core_model - - pytest -v -s models/multimodal/test_tensor_schema.py -m core_model # Needs mp_method="spawn" + - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing - cd .. 
&& pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work - label: Multi-Modal Models Test (Extended) 1 @@ -584,7 +575,7 @@ steps: - tests/models/multimodal commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - - pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model' + - pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing - label: Multi-Modal Models Test (Extended) 2 mirror_hardwares: [amdexperimental] @@ -647,8 +638,10 @@ steps: - vllm/model_executor/layers/fused_moe/cutlass_moe.py - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py - vllm/compilation/fusion.py + - vllm/compilation/fusion_attn.py commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py @@ -663,8 +656,11 @@ steps: - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py + - pytest -v -s tests/kernels/moe/test_mxfp4_moe.py # Fusion - pytest -v -s tests/compile/test_fusion_all_reduce.py + - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern + - pytest -v -s tests/kernels/moe/test_flashinfer.py ##### 1 GPU test ##### ##### multi gpus test ##### diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index b0dd5e99d4c72..ce9590f02ce71 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -10,6 +10,7 @@ /vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 +/vllm/model_executor/layers/mamba @tdoublep /vllm/multimodal @DarkLight1337 @ywang96 /vllm/vllm_flash_attn @LucasWilkinson /vllm/lora @jeejeelee @@ -25,11 +26,11 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson # vLLM V1 /vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat /vllm/v1/structured_output @mgoin @russellb @aarnphm +/vllm/v1/attention/backends/triton_attn.py @tdoublep # Test ownership /.buildkite/lm-eval-harness @mgoin @simon-mo /tests/async_engine @njhill @robertgshaw2-redhat @simon-mo -/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac /tests/distributed/test_multi_node_assignment.py @youkaichao /tests/distributed/test_pipeline_parallel.py @youkaichao /tests/distributed/test_same_node.py @youkaichao @@ -44,6 +45,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /tests/v1/structured_output @mgoin @russellb @aarnphm /tests/weight_loading @mgoin @youkaichao @yewentao256 /tests/lora @jeejeelee +/tests/models/language/generation/test_hybrid.py @tdoublep # Docs /docs @hmellor @@ -72,3 +74,9 @@ mkdocs.yaml @hmellor /vllm/model_executor/models/pixtral*.py @patrickvonplaten /vllm/transformers_utils/configs/mistral.py @patrickvonplaten /vllm/transformers_utils/tokenizers/mistral.py @patrickvonplaten + +# Kernels +/vllm/attention/ops/chunked_prefill_paged_decode.py @tdoublep +/vllm/attention/ops/triton_unified_attention.py @tdoublep + + diff --git 
a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml deleted file mode 100644 index 2b1086b7faf43..0000000000000 --- a/.github/workflows/lint-and-deploy.yaml +++ /dev/null @@ -1,89 +0,0 @@ -name: Lint and Deploy Charts - -on: pull_request - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -permissions: - contents: read - -jobs: - lint-and-deploy: - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - - name: Set up Helm - uses: azure/setup-helm@b9e51907a09c216f16ebe8536097933489208112 # v4.3.0 - with: - version: v3.14.4 - - #Python is required because ct lint runs Yamale and yamllint which require Python. - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 - with: - python-version: '3.13' - - - name: Set up chart-testing - uses: helm/chart-testing-action@0d28d3144d3a25ea2cc349d6e59901c4ff469b3b # v2.7.0 - with: - version: v3.10.1 - - - name: Run chart-testing (lint) - run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/online_serving/chart-helm --charts examples/online_serving/chart-helm - - - name: Setup minio - run: | - docker network create vllm-net - docker run -d -p 9000:9000 --name minio --net vllm-net \ - -e "MINIO_ACCESS_KEY=minioadmin" \ - -e "MINIO_SECRET_KEY=minioadmin" \ - -v /tmp/data:/data \ - -v /tmp/config:/root/.minio \ - minio/minio server /data - export AWS_ACCESS_KEY_ID=minioadmin - export AWS_SECRET_ACCESS_KEY=minioadmin - export AWS_EC2_METADATA_DISABLED=true - mkdir opt-125m - cd opt-125m && curl -O -Ls "https://huggingface.co/facebook/opt-125m/resolve/main/{pytorch_model.bin,config.json,generation_config.json,merges.txt,special_tokens_map.json,tokenizer_config.json,vocab.json}" && cd .. - aws --endpoint-url http://127.0.0.1:9000/ s3 mb s3://testbucket - aws --endpoint-url http://127.0.0.1:9000/ s3 cp opt-125m/ s3://testbucket/opt-125m --recursive - - - name: Create kind cluster - uses: helm/kind-action@a1b0e391336a6ee6713a0583f8c6240d70863de3 # v1.12.0 - - - name: Build the Docker image vllm cpu - run: docker buildx build -f docker/Dockerfile.cpu -t vllm-cpu-env . 
- - - name: Configuration of docker images, network and namespace for the kind cluster - run: | - docker pull amazon/aws-cli:2.6.4 - kind load docker-image amazon/aws-cli:2.6.4 --name chart-testing - kind load docker-image vllm-cpu-env:latest --name chart-testing - docker network connect vllm-net "$(docker ps -aqf "name=chart-testing-control-plane")" - kubectl create ns ns-vllm - - - name: Run chart-testing (install) - run: | - export AWS_ACCESS_KEY_ID=minioadmin - export AWS_SECRET_ACCESS_KEY=minioadmin - sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" & - helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env" - - - name: curl test - run: | - kubectl -n ns-vllm port-forward service/test-vllm-service 8001:80 & - sleep 10 - CODE="$(curl -v -f --location http://localhost:8001/v1/completions \ - --header "Content-Type: application/json" \ - --data '{ - "model": "opt-125m", - "prompt": "San Francisco is a", - "max_tokens": 7, - "temperature": 0 - }'):$CODE" - echo "$CODE" diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml deleted file mode 100644 index bfd02879965ee..0000000000000 --- a/.github/workflows/publish.yml +++ /dev/null @@ -1,111 +0,0 @@ -# This workflow will upload a Python Package to Release asset -# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions - -name: Create Release - -on: - push: - tags: - - v* - -# Needed to create release and upload assets -permissions: - contents: write - -jobs: - release: - # Retrieve tag and create release - name: Create Release - runs-on: ubuntu-latest - outputs: - upload_url: ${{ steps.create_release.outputs.upload_url }} - steps: - - name: Checkout - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - name: Extract branch info - shell: bash - run: | - echo "release_tag=${GITHUB_REF#refs/*/}" >> "$GITHUB_ENV" - - - name: Create Release - id: create_release - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 - env: - RELEASE_TAG: ${{ env.release_tag }} - with: - github-token: "${{ secrets.GITHUB_TOKEN }}" - script: | - const script = require('.github/workflows/scripts/create_release.js') - await script(github, context, core) - - # NOTE(simon): No longer build wheel using GitHub Actions. See buildkite's release workflow. 
- # wheel: - # name: Build Wheel - # runs-on: ${{ matrix.os }} - # needs: release - - # strategy: - # fail-fast: false - # matrix: - # os: ['ubuntu-20.04'] - # python-version: ['3.9', '3.10', '3.11', '3.12'] - # pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements/cuda.txt. - # cuda-version: ['11.8', '12.1'] - - # steps: - # - name: Checkout - # uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - # - name: Setup ccache - # uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14 - # with: - # create-symlink: true - # key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }} - - # - name: Set up Linux Env - # if: ${{ runner.os == 'Linux' }} - # run: | - # bash -x .github/workflows/scripts/env.sh - - # - name: Set up Python - # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - # with: - # python-version: ${{ matrix.python-version }} - - # - name: Install CUDA ${{ matrix.cuda-version }} - # run: | - # bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} - - # - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }} - # run: | - # bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }} - - # - name: Build wheel - # shell: bash - # env: - # CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size - # run: | - # bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} - # wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename) - # asset_name=${wheel_name//"linux"/"manylinux1"} - # echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV" - # echo "asset_name=${asset_name}" >> "$GITHUB_ENV" - - # - name: Upload Release Asset - # uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2 - # env: - # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - # with: - # upload_url: ${{ needs.release.outputs.upload_url }} - # asset_path: ./dist/${{ env.wheel_name }} - # asset_name: ${{ env.asset_name }} - # asset_content_type: application/* - - # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested - # - name: Publish package - # uses: pypa/gh-action-pypi-publish@release/v1.8 - # with: - # repository-url: https://test.pypi.org/legacy/ - # password: ${{ secrets.PYPI_API_TOKEN }} - # skip-existing: true diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml index 16ae1aadb96be..1ee605dc7bb0d 100644 --- a/.github/workflows/reminder_comment.yml +++ b/.github/workflows/reminder_comment.yml @@ -12,16 +12,43 @@ jobs: uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 with: script: | - github.rest.issues.createComment({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' + - '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' + - 'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. 
You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org.\n\n' + - 'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' + - 'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' + - '🚀' - }) + try { + // Get the PR author + const prAuthor = context.payload.pull_request.user.login; + + // Check if this is the author's first PR in this repository + // Use GitHub's search API to find all PRs by this author + const { data: searchResults } = await github.rest.search.issuesAndPullRequests({ + q: `repo:${context.repo.owner}/${context.repo.repo} type:pr author:${prAuthor}`, + per_page: 100 + }); + + const authorPRCount = searchResults.total_count; + + console.log(`Found ${authorPRCount} PRs by ${prAuthor}`); + + // Only post comment if this is the first PR (only one PR by this author) + if (authorPRCount === 1) { + console.log(`Posting welcome comment for first-time contributor: ${prAuthor}`); + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' + + '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' + + 'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. \n\n' + + 'You ask your reviewers to trigger select CI tests on top of `fastcheck` CI. \n\n' + + 'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' + + 'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' + + 'If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.\n\n' + + '🚀' + }); + } else { + console.log(`Skipping comment for ${prAuthor} - not their first PR (${authorPRCount} PRs found)`); + } + } catch (error) { + console.error('Error checking PR history or posting comment:', error); + // Don't fail the workflow, just log the error + } env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/CMakeLists.txt b/CMakeLists.txt index 34386d670ac76..a1deefb07f09c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -30,7 +30,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) # Supported python versions. These versions will be searched in order, the # first match will be selected. These should be kept in sync with setup.py. # -set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12") +set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12", "3.13") # Supported AMD GPU architectures. 
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201") @@ -357,9 +357,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) set(MARLIN_SRCS - "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" - "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") diff --git a/benchmarks/README.md b/benchmarks/README.md index 1d715a193ea14..176b40212978f 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -32,6 +32,14 @@ become available.
Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:
wget http://images.cocodataset.org/zips/train2017.zip + + + ShareGPT4Video (Video) + ✅ + ✅ + + git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video + BurstGPT @@ -194,6 +202,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct ```bash vllm bench serve \ --backend openai-chat \ + --endpoint-type openai-chat \ --model Qwen/Qwen2-VL-7B-Instruct \ --endpoint /v1/chat/completions \ --dataset-name hf \ @@ -230,6 +239,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct ```bash vllm bench serve \ --backend openai-chat \ + --endpoint-type openai-chat \ --model Qwen/Qwen2-VL-7B-Instruct \ --endpoint /v1/chat/completions \ --dataset-name hf \ @@ -244,6 +254,7 @@ vllm bench serve \ ```bash vllm bench serve \ --backend openai-chat \ + --endpoint-type openai-chat \ --model Qwen/Qwen2-VL-7B-Instruct \ --endpoint /v1/chat/completions \ --dataset-name hf \ @@ -609,7 +620,7 @@ vllm bench serve \ --prefix-repetition-prefix-len 512 \ --prefix-repetition-suffix-len 128 \ --prefix-repetition-num-prefixes 5 \ - --prefix-repetition-output-len 128 + --prefix-repetition-output-len 128 ``` @@ -684,4 +695,31 @@ python benchmarks/benchmark_serving.py \ --endpoint /v1/chat/completion ``` +### Videos (ShareGPT4Video) + +Start vLLM: + +```bash +python -m vllm.entrypoints.openai.api_server \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dtype bfloat16 \ + --limit-mm-per-prompt '{"video": 1}' \ + --allowed-local-media-path /path/to/sharegpt4video/videos +``` + +Send requests with videos: + +```bash +python benchmarks/benchmark_serving.py \ + --backend openai-chat \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dataset-name sharegpt \ + --dataset-path /path/to/ShareGPT4Video/llava_v1_5_mix665k_with_video_chatgpt72k_share4video28k.json \ + --num-prompts 100 \ + --save-result \ + --result-dir ~/vllm_benchmark_results \ + --save-detailed \ + --endpoint /v1/chat/completion +``` + diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 1559ca2d92841..ba7c733be0b25 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -34,6 +34,7 @@ class RequestFuncInput: multi_modal_content: Optional[dict | list[dict]] = None ignore_eos: bool = False language: Optional[str] = None + request_id: Optional[str] = None @dataclass @@ -71,6 +72,9 @@ async def async_request_tgi( "inputs": request_func_input.prompt, "parameters": params, } + headers = None + if request_func_input.request_id: + headers = {"x-request-id": request_func_input.request_id} output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len if request_func_input.ignore_eos: @@ -82,7 +86,9 @@ async def async_request_tgi( st = time.perf_counter() most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload) as response: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() @@ -145,6 +151,9 @@ async def async_request_trt_llm( } if request_func_input.ignore_eos: payload["min_length"] = request_func_input.output_len + headers = None + if request_func_input.request_id: + headers = {"x-request-id": request_func_input.request_id} output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -152,7 +161,9 @@ async def async_request_trt_llm( st = time.perf_counter() most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload) as response: + async with session.post( + url=api_url, json=payload, 
headers=headers + ) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() @@ -211,6 +222,8 @@ async def async_request_deepspeed_mii( "top_p": 1.0, } headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -283,6 +296,8 @@ async def async_request_openai_completions( if request_func_input.extra_body: payload.update(request_func_input.extra_body) headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -395,6 +410,8 @@ async def async_request_openai_chat_completions( "Content-Type": "application/json", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -491,6 +508,8 @@ async def async_request_openai_audio( headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id # Send audio file def to_bytes(y, sr): diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 572292a5aca46..2ea4f9ccaff2b 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -19,6 +19,7 @@ import logging import random from abc import ABC, abstractmethod from collections.abc import Mapping +from copy import deepcopy from dataclasses import dataclass from functools import cache from io import BytesIO @@ -54,6 +55,7 @@ class SampleRequest: expected_output_len: int multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None lora_request: Optional[LoRARequest] = None + request_id: Optional[str] = None # ----------------------------------------------------------------------------- @@ -155,7 +157,10 @@ class BenchmarkDataset(ABC): @abstractmethod def sample( - self, tokenizer: PreTrainedTokenizerBase, num_requests: int + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", ) -> list[SampleRequest]: """ Abstract method to generate sample requests from the dataset. @@ -167,6 +172,7 @@ class BenchmarkDataset(ABC): tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for processing the dataset's text. num_requests (int): The number of sample requests to generate. + request_id_prefix (str) The prefix of request_id. Returns: list[SampleRequest]: A list of sample requests generated from the @@ -175,7 +181,10 @@ class BenchmarkDataset(ABC): raise NotImplementedError("sample must be implemented in subclasses.") def maybe_oversample_requests( - self, requests: list[SampleRequest], num_requests: int + self, + requests: list[SampleRequest], + num_requests: int, + request_id_prefix: str = "", ) -> None: """ Oversamples the list of requests if its size is less than the desired @@ -183,11 +192,18 @@ class BenchmarkDataset(ABC): Args: requests (List[SampleRequest]): The current list of sampled - requests. num_requests (int): The target number of requests. + requests. + num_requests (int): The target number of requests. + request_id_prefix (str) The prefix of the request ids. 
""" if len(requests) < num_requests: random.seed(self.random_seed) - additional = random.choices(requests, k=num_requests - len(requests)) + additional = deepcopy( + random.choices(requests, k=num_requests - len(requests)) + ) + for i in range(len(additional)): + req = additional[i] + req.request_id = request_id_prefix + str(len(requests) + i) requests.extend(additional) logger.info("Oversampled requests to reach %d total samples.", num_requests) @@ -277,6 +293,41 @@ def process_image(image: Any) -> Mapping[str, Any]: ) +def process_video(video: Any) -> Mapping[str, Any]: + """ + Process a single video input and return a multimedia content dictionary. + + Supports the following input types: + + 1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key + containing raw video data. + + 2. String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the image URL. + + Raises: + ValueError: If the input is not a supported type. + """ + if isinstance(video, dict) and "bytes" in video: + video_bytes = video["bytes"] + video_base64 = base64.b64encode(video_bytes).decode("utf-8") + return { + "type": "video_url", + "video_url": {"url": f"data:video/mp4;base64,{video_base64}"}, + } + + if isinstance(video, str): + video_url = ( + video if video.startswith(("http://", "file://")) else f"file://{video}" + ) + return {"type": "video_url", "video_url": {"url": video_url}} + + raise ValueError( + f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501 + ) + + # ----------------------------------------------------------------------------- # Random Dataset Implementation (Synthetic Data) # ----------------------------------------------------------------------------- @@ -303,6 +354,7 @@ class RandomDataset(BenchmarkDataset): range_ratio: float = DEFAULT_RANGE_RATIO, input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, + request_id_prefix: str = "", **kwargs, ) -> list[SampleRequest]: # Enforce range_ratio < 1 @@ -363,8 +415,10 @@ class RandomDataset(BenchmarkDataset): prompt=prompt, prompt_len=total_input_len, expected_output_len=int(output_lens[i]), + request_id=request_id_prefix + str(i), ) ) + return requests @@ -406,9 +460,11 @@ class ShareGPTDataset(BenchmarkDataset): max_loras: Optional[int] = None, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: samples: list = [] + ind = 0 for entry in self.data: if len(samples) >= num_requests: break @@ -430,9 +486,10 @@ class ShareGPTDataset(BenchmarkDataset): skip_min_output_len_check=output_len is not None, ): continue - # TODO: Also support ShareGPT4Video. 
if image_path := entry.get("image"): mm_content = process_image(image_path) + elif video_path := entry.get("video"): + mm_content = process_video(video_path) else: mm_content = None if enable_multimodal_chat: @@ -444,9 +501,11 @@ class ShareGPTDataset(BenchmarkDataset): expected_output_len=new_output_len, lora_request=lora_request, multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), ) ) - self.maybe_oversample_requests(samples, num_requests) + ind += 1 + self.maybe_oversample_requests(samples, num_requests, request_id_prefix) return samples @@ -512,10 +571,11 @@ class CustomDataset(BenchmarkDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, skip_chat_template: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break prompt = item["prompt"] @@ -534,9 +594,12 @@ class CustomDataset(BenchmarkDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -578,6 +641,7 @@ class SonnetDataset(BenchmarkDataset): input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, return_prompt_formatted: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: # Calculate average token length for a poem line. @@ -603,6 +667,7 @@ class SonnetDataset(BenchmarkDataset): prefix_lines = self.data[:num_prefix_lines] samples = [] + ind = 0 while len(samples) < num_requests: extra_lines = random.choices( self.data, k=num_input_lines - num_prefix_lines @@ -613,14 +678,17 @@ class SonnetDataset(BenchmarkDataset): msg, add_generation_prompt=True, tokenize=False ) prompt_len = len(tokenizer(prompt_formatted).input_ids) + if prompt_len <= input_len: samples.append( SampleRequest( prompt=prompt_formatted if return_prompt_formatted else prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(ind), ) ) + ind += 1 return samples @@ -672,6 +740,7 @@ class BurstGPTDataset(BenchmarkDataset): num_requests: int, max_loras: Optional[int] = None, lora_path: Optional[str] = None, + request_id_prefix: str = "", **kwargs, ) -> list[SampleRequest]: samples = [] @@ -693,6 +762,7 @@ class BurstGPTDataset(BenchmarkDataset): prompt_len=input_len, expected_output_len=output_len, lora_request=lora_req, + request_id=request_id_prefix + str(i), ) ) return samples @@ -752,12 +822,14 @@ class ConversationDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: # Filter examples with at least 2 conversations filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) sampled_requests = [] dynamic_output = output_len is None + ind = 0 for item in filtered_data: if len(sampled_requests) >= num_requests: @@ -785,9 +857,13 @@ class ConversationDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + ind += 1 + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -814,11 +890,12 @@ class 
VisionArenaDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) @@ -838,9 +915,12 @@ class VisionArenaDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(i), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -870,15 +950,18 @@ class InstructCoderDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break - prompt = f"{item['input']}\n\n{item['instruction']} Just output \ - the code, do not include any explanation." + prompt = ( + f"{item['input']}\n\n{item['instruction']} Just output " + "the code, do not include any explanation." + ) # apply template prompt = tokenizer.apply_chat_template( @@ -892,9 +975,12 @@ class InstructCoderDataset(HuggingFaceDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -924,12 +1010,13 @@ class MTBenchDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break prompt = item["turns"][0] @@ -947,9 +1034,12 @@ class MTBenchDataset(HuggingFaceDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -974,10 +1064,12 @@ class AIMODataset(HuggingFaceDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, output_len: Optional[int] = None, + request_id_prefix: str = "", **kwargs, ) -> list: sampled_requests = [] dynamic_output = output_len is None + ind = 0 for item in self.data: if len(sampled_requests) >= num_requests: @@ -1000,9 +1092,13 @@ class AIMODataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=None, + request_id=request_id_prefix + str(ind), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + ind += 1 + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -1072,12 +1168,18 @@ class NextEditPredictionDataset(HuggingFaceDataset): "zed-industries/zeta": _format_zeta_prompt, 
} - def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs): + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", + **kwargs, + ): formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path) if formatting_prompt_func is None: raise ValueError(f"Unsupported dataset path: {self.dataset_path}") samples = [] - for sample in self.data: + for i, sample in enumerate(self.data): sample = formatting_prompt_func(sample) samples.append( SampleRequest( @@ -1086,11 +1188,12 @@ class NextEditPredictionDataset(HuggingFaceDataset): expected_output_len=len( tokenizer(sample["expected_output"]).input_ids ), + request_id=request_id_prefix + str(i), ) ) if len(samples) >= num_requests: break - self.maybe_oversample_requests(samples, num_requests) + self.maybe_oversample_requests(samples, num_requests, request_id_prefix) return samples @@ -1139,6 +1242,7 @@ class ASRDataset(HuggingFaceDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, output_len: Optional[int] = None, + request_id_prefix: str = "", **kwargs, ) -> list: import librosa @@ -1148,6 +1252,7 @@ class ASRDataset(HuggingFaceDataset): prompt_len = len(tokenizer(prompt).input_ids) sampled_requests = [] skipped = 0 + ind = 0 for item in self.data: if len(sampled_requests) >= num_requests: break @@ -1166,8 +1271,10 @@ class ASRDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), ) ) + ind += 1 if skipped: logger.warning( "%d samples discarded from dataset due to" @@ -1175,5 +1282,7 @@ class ASRDataset(HuggingFaceDataset): " what Whisper supports.", skipped, ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index ae38caf7290b1..02f5f585c0c16 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -375,11 +375,12 @@ async def benchmark( rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) last_int_rps = current_int_rps - prompt, prompt_len, output_len, mm_content = ( + prompt, prompt_len, output_len, mm_content, request_id = ( request.prompt, request.prompt_len, request.expected_output_len, request.multi_modal_data, + request.request_id, ) req_model_id, req_model_name = model_id, model_name if lora_modules: @@ -397,6 +398,7 @@ async def benchmark( multi_modal_content=mm_content, ignore_eos=ignore_eos, extra_body=extra_body, + request_id=request_id, ) task = limited_request_func(request_func_input=request_func_input, pbar=pbar) tasks.append(asyncio.create_task(task)) @@ -665,6 +667,7 @@ def main(args: argparse.Namespace): tokenizer=tokenizer, output_len=args.custom_output_len, skip_chat_template=args.custom_skip_chat_template, + request_id_prefix=args.request_id_prefix, ) elif args.dataset_name == "sonnet": @@ -678,6 +681,7 @@ def main(args: argparse.Namespace): prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, return_prompt_formatted=False, + request_id_prefix=args.request_id_prefix, ) else: assert tokenizer.chat_template or tokenizer.default_chat_template, ( @@ -690,6 +694,7 @@ def main(args: argparse.Namespace): prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, return_prompt_formatted=True, + request_id_prefix=args.request_id_prefix, ) elif args.dataset_name == "hf": @@ -751,6 
+756,7 @@ def main(args: argparse.Namespace): num_requests=args.num_prompts, tokenizer=tokenizer, output_len=args.hf_output_len, + request_id_prefix=args.request_id_prefix, ) else: @@ -762,10 +768,15 @@ def main(args: argparse.Namespace): tokenizer=tokenizer, num_requests=args.num_prompts, output_len=args.sharegpt_output_len, + request_id_prefix=args.request_id_prefix, ), "burstgpt": lambda: BurstGPTDataset( random_seed=args.seed, dataset_path=args.dataset_path - ).sample(tokenizer=tokenizer, num_requests=args.num_prompts), + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + request_id_prefix=args.request_id_prefix, + ), "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample( tokenizer=tokenizer, num_requests=args.num_prompts, @@ -773,6 +784,7 @@ def main(args: argparse.Namespace): input_len=args.random_input_len, output_len=args.random_output_len, range_ratio=args.random_range_ratio, + request_id_prefix=args.request_id_prefix, ), } @@ -1118,6 +1130,13 @@ def create_argument_parser(): "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "and the blog: https://hao-ai-lab.github.io/blogs/distserve", ) + parser.add_argument( + "--request-id-prefix", + type=str, + required=False, + default="benchmark-serving", + help="Specify the prefix of request id.", + ) # group for dataset specific arguments custom_group = parser.add_argument_group("custom dataset options") diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index c51b579686529..c7f290e1eb88e 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -597,8 +597,8 @@ def validate_args(args): # https://github.com/vllm-project/vllm/issues/16222 if args.data_parallel_size > 1: raise ValueError( - "Data parallel is not supported in offline benchmark, \ - please use benchmark serving instead" + "Data parallel is not supported in offline benchmark, " + "please use benchmark serving instead" ) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 1d4e730f99ae9..a6b42406b5cb0 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -80,6 +80,11 @@ def bench_run( a, score, topk, renormalize=False ) + ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) + def run_triton_moe( a: torch.Tensor, w1: torch.Tensor, @@ -111,6 +116,10 @@ def bench_run( w2: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, per_act_token: bool, @@ -125,6 +134,10 @@ def bench_run( topk_ids, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, per_act_token, a1_scale=None, ) @@ -136,6 +149,10 @@ def bench_run( w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, ): @@ -150,6 +167,10 @@ def bench_run( topk_ids, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + 
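For the CUTLASS grouped-GEMM benchmark, the per-expert stride tensors are now materialized once up front: with hidden size `k` and intermediate size `n`, the first GEMM consumes `[*, k]` activations and produces `[*, 2*n]` (gate and up projections fused), and the second GEMM maps `[*, n]` back to `[*, k]`. A short sketch of how those tensors are shaped, with illustrative sizes (the real values come from the benchmark's weight shapes):

```python
import torch

num_experts, k, n = 8, 4096, 7168  # illustrative sizes, not the benchmark's defaults
device = "cuda" if torch.cuda.is_available() else "cpu"

# One stride entry per expert; all experts share the same row-major layout here.
ab_strides1 = torch.full((num_experts,), k, device=device, dtype=torch.int64)     # inputs to GEMM 1
ab_strides2 = torch.full((num_experts,), n, device=device, dtype=torch.int64)     # inputs to GEMM 2
c_strides1 = torch.full((num_experts,), 2 * n, device=device, dtype=torch.int64)  # GEMM 1 output (gate + up)
c_strides2 = torch.full((num_experts,), k, device=device, dtype=torch.int64)      # GEMM 2 output
```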
c_strides2, per_act_token, a1_scale=None, ) @@ -194,6 +215,10 @@ def bench_run( w2_q, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, topk_weights, topk_ids, ) @@ -231,6 +256,10 @@ def bench_run( "w1_scale": w1_scale, "w2_scale": w2_scale, "per_act_token": per_act_token, + "ab_strides1": ab_strides1, + "ab_strides2": ab_strides2, + "c_strides1": c_strides1, + "c_strides2": c_strides2, # cuda graph params "cutlass_graph": cutlass_graph, "triton_graph": triton_graph, @@ -289,6 +318,10 @@ def bench_run( w2_q, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, topk_weights, topk_ids, per_act_token, @@ -297,7 +330,7 @@ def bench_run( results.append( benchmark.Timer( - stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501 + stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 975d10f2e92ec..a9c4d30d9b189 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -253,28 +253,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: else: assert bt.a.dtype == torch.int8 assert bt.wtype == scalar_types.uint4b8 - - if bt.w_ch_s is not None: - s_ch = bt.w_ch_s.to(torch.float32) - else: - s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device) - - if bt.w_tok_s is not None: - s_tok = bt.w_tok_s.to(torch.float32) - else: - s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device) - - fn = lambda: ops.marlin_qqq_gemm( - a=bt.a, - b_q_weight=w_q, - s_group=w_s, - s_tok=s_tok, - s_ch=s_ch, - workspace=workspace.scratch, - size_m=bt.a.shape[0], - size_n=bt.w_ref.shape[1], - size_k=bt.w_ref.shape[0], - ) + raise NotImplementedError("QQQ is not supported anymore") return fn diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index b4a03665ef10f..752c2d0082167 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -430,7 +430,6 @@ class BenchmarkWorker: hidden_size, topk, dtype_str, - is_marlin=False, ) else: config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py new file mode 100644 index 0000000000000..0650cbf3cc18e --- /dev/null +++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time + +import torch + +from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + silu_mul_fp8_quant_deep_gemm, +) +from vllm.platforms import current_platform + + +def benchmark(E, T, H, G=128, runs=50): + current_platform.seed_everything(42) + y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") + tokens_per_expert = torch.randint( + T // 2, T, size=(E,), dtype=torch.int32, device="cuda" + ) + + # Warmup + for _ in range(10): + silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) + torch.cuda.synchronize() + + # Benchmark + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(runs): + 
silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) + torch.cuda.synchronize() + + avg_time = (time.perf_counter() - start) / runs * 1000 + + # Calculate actual work done (only count valid tokens) + actual_tokens = tokens_per_expert.sum().item() + actual_elements = actual_tokens * H + + # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops + ops_per_element = 8 + total_ops = actual_elements * ops_per_element + gflops = total_ops / (avg_time / 1000) / 1e9 + + # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes) + input_bytes = actual_tokens * 2 * H * 2 # 2*H bfloat16 inputs + output_bytes = actual_tokens * H * 1 # H fp8 outputs + scale_bytes = actual_tokens * (H // G) * 4 # scales in float32 + total_bytes = input_bytes + output_bytes + scale_bytes + memory_bw = total_bytes / (avg_time / 1000) / 1e9 + + return avg_time, gflops, memory_bw + + +configs = [ + (8, 32, 1024), + (16, 64, 2048), + (32, 128, 4096), + # DeepSeekV3 Configs + (256, 16, 7168), + (256, 32, 7168), + (256, 64, 7168), + (256, 128, 7168), + (256, 256, 7168), + (256, 512, 7168), + (256, 1024, 7168), +] + +print(f"GPU: {torch.cuda.get_device_name()}") +print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}") +print("-" * 50) + +for E, T, H in configs: + try: + time_ms, gflops, gbps = benchmark(E, T, H) + print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}") + except Exception: + print(f"E={E:3d},T={T:4d},H={H:4d} FAILED") diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py index 77136edca45b5..72b54b40a2d1e 100644 --- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py @@ -3,16 +3,14 @@ import csv import os -import random from datetime import datetime +from typing import Optional import flashinfer import torch FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 - -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) +FP8_DTYPE = torch.float8_e4m3fn def to_float8(x, dtype=torch.float8_e4m3fn): @@ -26,65 +24,107 @@ def to_float8(x, dtype=torch.float8_e4m3fn): @torch.no_grad() def benchmark_decode( - num_seqs, - max_seq_len, - page_size=16, - dtype=torch.bfloat16, - kv_layout="HND", - num_kv_heads=8, - kv_cache_dtype="auto", - head_dim=128, - warmup=10, - trials=20, + dtype: torch.dtype, + quant_dtypes: tuple[ + Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] + ], + batch_size: int, + max_seq_len: int, + num_heads: tuple[int, int] = (64, 8), + head_size: int = 128, + kv_layout: str = "HND", + block_size: int = 16, + warmup: int = 10, + trials: int = 20, ): torch.set_default_device("cuda") - device = "cuda" torch.manual_seed(0) - HEAD_GRP_SIZE = 8 - MAX_SEQ_LEN = max_seq_len + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 + + sm_scale = float(1.0 / (head_size**0.5)) # large number to reduce kv_cache reuse - NUM_BLOCKS = int(256000 / page_size) + NUM_BLOCKS = int(256000 / block_size) - workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device) + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif 
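The byte accounting in the new `benchmark_silu_mul_fp8_quant.py` above amounts to a simple roofline estimate: bf16 inputs contribute 2 bytes per element over `2*H` columns, fp8 outputs 1 byte over `H`, and one fp32 scale per `group_size` block. A self-contained restatement of that arithmetic (pure Python, no GPU required), useful for sanity-checking the printed GB/s figures; the token count in the example is made up:

```python
def estimate_bytes_and_flops(actual_tokens: int, H: int, group_size: int = 128,
                             ops_per_element: int = 8) -> tuple[int, int]:
    """Mirror the benchmark's traffic/FLOP model for silu_mul + fp8 quantization."""
    input_bytes = actual_tokens * 2 * H * 2               # bf16 gate+up inputs: 2*H values, 2 bytes each
    output_bytes = actual_tokens * H * 1                   # fp8 outputs: H values, 1 byte each
    scale_bytes = actual_tokens * (H // group_size) * 4    # one fp32 scale per group of H elements
    total_bytes = input_bytes + output_bytes + scale_bytes
    total_ops = actual_tokens * H * ops_per_element        # ~8 ops per output element, as in the benchmark
    return total_bytes, total_ops


# Example: 256 experts with ~96 valid tokens each, H = 7168
total_bytes, total_ops = estimate_bytes_and_flops(256 * 96, 7168)
print(f"{total_bytes / 1e9:.2f} GB moved, {total_ops / 1e9:.2f} GFLOP")
```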
kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") - # For decode, batch_size is num_decode_token - num_qo_heads = num_kv_heads * HEAD_GRP_SIZE - sm_scale = float(1.0 / (head_dim**0.5)) - q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype) - kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, q_scale = to_float8(query) + ref_query = query.to(dtype) * q_scale + else: + q_scale = 1.0 + ref_query = query - max_kv_len = max(kv_lens) - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device) - max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size + kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32) + kv_lens[-1] = max_seq_len + seq_lens = kv_lens + max_seq_len = torch.max(seq_lens).item() + + kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, kv_scale = to_float8(kv_cache) + ref_kv_cache = kv_cache.to(dtype) * kv_scale + else: + kv_scale = 1.0 + ref_kv_cache = kv_cache + k_scale = v_scale = kv_scale + + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint( - 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 ) + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(batch_size): + seq_len = seq_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) - kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) - kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype) - k_scale = v_scale = 1.0 + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8) - if kv_cache_dtype.startswith("fp8"): - kv_cache, _ = to_float8(kv_cache) - - output_trtllm = torch.empty(q.shape, dtype=dtype) - - # Benchmark TRT decode - def trt_decode(): - return flashinfer.decode.trtllm_batch_decode_with_kv_cache( - q, - kv_cache, - workspace_buffer, - block_tables, - kv_lens_tensor, - max_kv_len, - bmm1_scale=k_scale * sm_scale, - bmm2_scale=v_scale, - out=output_trtllm, - ) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout, + use_tensor_cores=True, + ) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_lens, + num_qo_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + sm_scale=sm_scale, + q_data_type=dtype, + kv_data_type=dtype, + ) def time_fn(fn, warmup=10, trials=20): torch.cuda.synchronize() @@ -101,74 +141,51 @@ def benchmark_decode( times.append(start.elapsed_time(end)) # ms return sum(times) / len(times), torch.std(torch.tensor(times)) - # TRT Decode - trt_mean, trt_std = time_fn(trt_decode) - - kv_indptr = [0] - kv_indices = [] - kv_last_page_lens = [] - for i in range(num_seqs): - seq_len = kv_lens[i] - assert seq_len > 0 - num_blocks = (seq_len + page_size - 1) // 
page_size - kv_indices.extend(block_tables[i, :num_blocks]) - kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % page_size - if kv_last_page_len == 0: - kv_last_page_len = page_size - kv_last_page_lens.append(kv_last_page_len) - - kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) - kv_indices = torch.tensor(kv_indices, dtype=torch.int32) - kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - - output_baseline = torch.empty(q.shape, dtype=dtype) - - wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout, - use_tensor_cores=((num_qo_heads // num_kv_heads) > 4), - ) - - wrapper.plan( - kv_indptr, - kv_indices, - kv_last_page_lens, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - "NONE", - q_data_type=dtype, - kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype, - ) + o_scale = 1.0 + output_baseline = torch.empty(ref_query.shape, dtype=dtype) + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) def baseline_decode(): - return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline) + return wrapper.run(ref_query, ref_kv_cache, out=output_baseline) + + def trtllm_decode(): + return flashinfer.decode.trtllm_batch_decode_with_kv_cache( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + out=output_trtllm, + ) baseline_mean, baseline_std = time_fn(baseline_decode) + trtllm_mean, trtllm_std = time_fn(trtllm_decode) # Calculate percentage speedup (positive means TRT is faster) - speedup_percent = (baseline_mean - trt_mean) / baseline_mean + speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean print( - f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}" + f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}" f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}" ) # Return results for CSV writing return { - "num_seqs": num_seqs, - "trt_mean": trt_mean, - "trt_std": trt_std.item(), + "batch_size": batch_size, + "trtllm_mean": trtllm_mean, + "trtllm_std": trtllm_std.item(), "baseline_mean": baseline_mean, "baseline_std": baseline_std.item(), "speedup_percent": speedup_percent, - "q_dtype": str(dtype), - "kv_cache_dtype": kv_cache_dtype, - "page_size": page_size, + "q_dtype": str(q_quant_dtype), + "kv_cache_dtype": str(kv_quant_dtype), + "output_dtype": str(o_quant_dtype), + "block_size": block_size, "num_kv_heads": num_kv_heads, - "head_dim": head_dim, + "head_size": head_size, "max_seq_len": max_seq_len, } @@ -180,17 +197,18 @@ def write_results_to_csv(results, filename=None): filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" fieldnames = [ - "num_seqs", - "trt_mean", - "trt_std", + "batch_size", + "trtllm_mean", + "trtllm_std", "baseline_mean", "baseline_std", "speedup_percent", "q_dtype", "kv_cache_dtype", - "page_size", + "output_dtype", + "block_size", "num_kv_heads", - "head_dim", + "head_size", "max_seq_len", ] @@ -209,45 +227,42 @@ def write_results_to_csv(results, filename=None): if __name__ == "__main__": - num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] + batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256] max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] all_results = [] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " - "output_dtype: bfloat16" - ) - print( 
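Both the decode and prefill benchmarks build the FlashInfer paged-KV metadata the same way: `kv_indptr` holds a cumulative block count per sequence, `kv_indices` lists the physical block IDs in sequence order, and `kv_last_page_lens` records how full each sequence's final block is (a full block when `seq_len % block_size == 0`). A CPU-only sketch of that construction with made-up sequence lengths; in the benchmarks the indices come from `block_tables[i, :num_blocks]` rather than a running offset:

```python
block_size = 16
seq_lens = [5, 16, 40]  # toy sequence lengths

kv_indptr, kv_indices, kv_last_page_lens = [0], [], []
offset = 0
for seq_len in seq_lens:
    num_blocks = (seq_len + block_size - 1) // block_size
    kv_indices.extend(range(offset, offset + num_blocks))  # stand-in for block_tables[i, :num_blocks]
    offset += num_blocks
    kv_indptr.append(kv_indptr[-1] + num_blocks)
    last = seq_len % block_size
    kv_last_page_lens.append(last if last else block_size)

print(kv_indptr)          # [0, 1, 2, 5]
print(kv_last_page_lens)  # [5, 16, 8]
```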
- "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - "baseline_std\tspeedup_percent" - ) - for max_seq_len in max_seq_lens: - for bs in num_seqs: - result = benchmark_decode( - bs, - max_seq_len, - dtype=torch.bfloat16, - kv_cache_dtype="auto", - ) - all_results.append(result) + dtype = torch.bfloat16 + quant_dtypes = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (None, FP8_DTYPE, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + ] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, " - "output_dtype: bfloat16" - ) - print( - "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - "baseline_std\tspeedup_percent" - ) - for max_seq_len in max_seq_lens: - for bs in num_seqs: - result = benchmark_decode( - bs, - max_seq_len, - dtype=torch.bfloat16, - kv_cache_dtype="fp8", - ) - all_results.append(result) + for quant_dtype in quant_dtypes: + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + print( + f"Running benchmark for q_dtype = {q_quant_dtype}, " + f"kv_cache_dtype: {kv_quant_dtype}, " + f"output_dtype: {o_quant_dtype}" + ) + print( + "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t" + "baseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in batch_sizes: + result = benchmark_decode( + dtype=dtype, + quant_dtypes=quant_dtype, + batch_size=bs, + max_seq_len=max_seq_len, + ) + all_results.append(result) # Write all results to CSV write_results_to_csv(all_results) diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py index 67bd9aebbcca9..49810e20c7d82 100644 --- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py @@ -3,16 +3,14 @@ import csv import os -import random from datetime import datetime +from typing import Optional import flashinfer import torch FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 - -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) +FP8_DTYPE = torch.float8_e4m3fn def to_float8(x, dtype=torch.float8_e4m3fn): @@ -26,84 +24,99 @@ def to_float8(x, dtype=torch.float8_e4m3fn): @torch.no_grad() def benchmark_prefill( - num_seqs, - max_seq_len, - page_size=16, - dtype=torch.bfloat16, - kv_layout="HND", - num_kv_heads=8, - kv_cache_dtype="auto", - head_dim=128, - warmup=10, - trials=20, + dtype: torch.dtype, + quant_dtypes: tuple[ + Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] + ], + batch_size: int, + max_seq_len: int, + num_heads: tuple[int, int] = (64, 8), + head_size: int = 128, + kv_layout: str = "HND", + block_size: int = 16, + warmup: int = 10, + trials: int = 20, ): torch.set_default_device("cuda") torch.manual_seed(0) - HEAD_GRP_SIZE = 8 - MAX_SEQ_LEN = max_seq_len + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + max_q_len = max_kv_len = max_seq_len + + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 + + sm_scale = float(1.0 / (head_size**0.5)) # large number to reduce kv_cache reuse - NUM_BLOCKS = int(256000 / page_size) + NUM_BLOCKS = int(256000 / block_size) - workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8) 
+ kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") - num_qo_heads = num_kv_heads * HEAD_GRP_SIZE - sm_scale = float(1.0 / (head_dim**0.5)) - - q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - q_lens[-1] = MAX_SEQ_LEN - max_q_len = max(q_lens) + q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32) + q_lens[-1] = max_q_len q_indptr = torch.cat( [ torch.tensor([0], dtype=torch.int32), - torch.cumsum( - torch.tensor(q_lens, dtype=torch.int32), dim=0, dtype=torch.int32 - ), + torch.cumsum(q_lens, dim=0, dtype=torch.int32), ] ) - q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype) - kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)] - kv_lens[-1] = MAX_SEQ_LEN + query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, q_scale = to_float8(query) + ref_query = query.to(dtype) * q_scale + else: + q_scale = 1.0 + ref_query = query - seq_lens = [q_len + kv_len for q_len, kv_len in zip(q_lens, kv_lens)] - max_seq_len = max(seq_lens) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32) + kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32) + kv_lens[-1] = max_kv_len - max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size + seq_lens = kv_lens + q_lens + max_seq_len = torch.max(seq_lens).item() + + kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, kv_scale = to_float8(kv_cache) + ref_kv_cache = kv_cache.to(dtype) * kv_scale + else: + kv_scale = 1.0 + ref_kv_cache = kv_cache + k_scale = v_scale = kv_scale + + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint( - 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 ) - - kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) - kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype) - k_scale = v_scale = 1.0 - - if kv_cache_dtype.startswith("fp8"): - kv_cache, _ = to_float8(kv_cache) - - output_trtllm = torch.empty(q.shape, dtype=dtype) - kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] - for i in range(num_seqs): + for i in range(batch_size): seq_len = seq_lens[i] assert seq_len > 0 - num_blocks = (seq_len + page_size - 1) // page_size + num_blocks = (seq_len + block_size - 1) // block_size kv_indices.extend(block_tables[i, :num_blocks]) kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % page_size + kv_last_page_len = seq_len % block_size if kv_last_page_len == 0: - kv_last_page_len = page_size + kv_last_page_len = block_size kv_last_page_lens.append(kv_last_page_len) kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - - output_baseline = torch.empty(q.shape, dtype=dtype) + workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8) wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout @@ -115,12 +128,12 @@ def benchmark_prefill( kv_last_page_lens, num_qo_heads, num_kv_heads, - head_dim, - page_size, + head_size, + block_size, causal=True, 
sm_scale=sm_scale, q_data_type=dtype, - kv_data_type=kv_cache.dtype, + kv_data_type=dtype, ) def time_fn(fn, warmup=10, trials=20): @@ -138,52 +151,55 @@ def benchmark_prefill( times.append(start.elapsed_time(end)) # ms return sum(times) / len(times), torch.std(torch.tensor(times)) - def baseline_prefill(): - return wrapper.run( - q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline - ) + o_scale = 1.0 + output_baseline = torch.empty(ref_query.shape, dtype=dtype) + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) - def trt_prefill(): + def baseline_prefill(): + return wrapper.run(ref_query, ref_kv_cache, out=output_baseline) + + def trtllm_prefill(): return flashinfer.prefill.trtllm_batch_context_with_kv_cache( - query=q, + query=query, kv_cache=kv_cache, workspace_buffer=workspace_buffer, block_tables=block_tables, - seq_lens=seq_lens_tensor, + seq_lens=seq_lens, max_q_len=max_q_len, max_kv_len=max_seq_len, - bmm1_scale=k_scale * sm_scale, - bmm2_scale=v_scale, - batch_size=num_seqs, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, out=output_trtllm, ) - trt_mean, trt_std = time_fn(trt_prefill) baseline_mean, baseline_std = time_fn(baseline_prefill) + trtllm_mean, trtllm_std = time_fn(trtllm_prefill) # Calculate percentage speedup (positive means TRT is faster) - speedup_percent = (baseline_mean - trt_mean) / baseline_mean + speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean print( - f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}" - f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}" + f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}" + f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}" ) # Return results for CSV writing return { - "num_seqs": num_seqs, - "trt_mean": trt_mean, - "trt_std": trt_std.item(), + "batch_size": batch_size, + "trtllm_mean": trtllm_mean, + "trtllm_std": trtllm_std.item(), "baseline_mean": baseline_mean, "baseline_std": baseline_std.item(), "speedup_percent": speedup_percent, - "q_dtype": str(dtype), - "kv_cache_dtype": kv_cache_dtype, - "page_size": page_size, + "q_dtype": str(q_quant_dtype), + "kv_cache_dtype": str(kv_quant_dtype), + "output_dtype": str(o_quant_dtype), + "block_size": block_size, "num_kv_heads": num_kv_heads, - "head_dim": head_dim, + "head_size": head_size, "max_seq_len": max_seq_len, } @@ -195,17 +211,18 @@ def write_results_to_csv(results, filename=None): filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" fieldnames = [ - "num_seqs", - "trt_mean", - "trt_std", + "batch_size", + "trtllm_mean", + "trtllm_std", "baseline_mean", "baseline_std", "speedup_percent", "q_dtype", "kv_cache_dtype", - "page_size", + "output_dtype", + "block_size", "num_kv_heads", - "head_dim", + "head_size", "max_seq_len", ] @@ -224,27 +241,41 @@ def write_results_to_csv(results, filename=None): if __name__ == "__main__": - num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] + batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256] max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] all_results = [] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " - "output_dtype: bfloat16" - ) - print( - "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - "baseline_std\tspeedup_percent" - ) - for max_seq_len in max_seq_lens: - for bs in num_seqs: - result = benchmark_prefill( - bs, - max_seq_len, - 
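The scale folding used in the TRT-LLM calls above (`bmm1_scale = q_scale * k_scale * sm_scale`, `bmm2_scale = v_scale / o_scale`) works because the query/key scales can be pulled out of the QK^T product before the softmax, and the value/output scales out of the PV product after it. A toy float demonstration of that identity in NumPy, with illustrative scales and `o_scale = 1.0`; this sketches the arithmetic only, not the fp8 kernels themselves:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
sm_scale = 1.0 / np.sqrt(d)
q_scale, k_scale, v_scale, o_scale = 0.02, 0.03, 0.05, 1.0  # illustrative scales

# "Quantized" tensors stand in for fp8 values; dequantized value = raw * scale.
Qq, Kq, Vq = (rng.standard_normal((4, d)) for _ in range(3))


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


# Reference: dequantize first, then attend.
ref = softmax(sm_scale * (q_scale * Qq) @ (k_scale * Kq).T) @ (v_scale * Vq)

# Folded: pass raw values and fold the scales into the two GEMMs,
# as bmm1_scale / bmm2_scale do in the benchmarks above.
bmm1_scale = q_scale * k_scale * sm_scale
bmm2_scale = v_scale / o_scale
folded = softmax(bmm1_scale * Qq @ Kq.T) @ Vq * bmm2_scale

assert np.allclose(ref, folded)
```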
dtype=torch.bfloat16, - kv_cache_dtype="auto", - ) - all_results.append(result) + dtype = torch.bfloat16 + quant_dtypes = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + ] + + for quant_dtype in quant_dtypes: + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + print( + f"Running benchmark for q_dtype = {q_quant_dtype}, " + f"kv_cache_dtype: {kv_quant_dtype}, " + f"output_dtype: {o_quant_dtype}" + ) + print( + "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t" + "baseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in batch_sizes: + result = benchmark_prefill( + dtype=dtype, + quant_dtypes=quant_dtype, + batch_size=bs, + max_seq_len=max_seq_len, + ) + all_results.append(result) # Write all results to CSV write_results_to_csv(all_results) diff --git a/benchmarks/multi_turn/README.md b/benchmarks/multi_turn/README.md index ae0866ae60751..7adf97bcf5622 100644 --- a/benchmarks/multi_turn/README.md +++ b/benchmarks/multi_turn/README.md @@ -5,11 +5,13 @@ The requirements (pip) for `benchmark_serving_multi_turn.py` can be found in `re First start serving your model ```bash -export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ +export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ -vllm serve $MODEL_NAME --disable-log-requests +vllm serve $MODEL_PATH --served-model-name Llama --disable-log-requests ``` +The variable `MODEL_PATH` should be a path to the model files (e.g. downloaded from huggingface). + ## Synthetic Multi-Turn Conversations Download the following text file (used for generation of synthetic conversations) @@ -26,10 +28,10 @@ But you may use other text files if you prefer (using this specific file is not Then run the benchmarking script ```bash -export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ +export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ -python benchmark_serving_multi_turn.py --model $MODEL_NAME --input-file generate_multi_turn.json \ ---num-clients 2 --max-active-conversations 6 +python benchmark_serving_multi_turn.py --model $MODEL_PATH --served-model-name Llama \ +--input-file generate_multi_turn.json --num-clients 2 --max-active-conversations 6 ``` You can edit the file `generate_multi_turn.json` to change the conversation parameters (number of turns, etc.). diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py index 53c3207491d18..d23b7b6e4571d 100644 --- a/benchmarks/multi_turn/benchmark_serving_multi_turn.py +++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py @@ -825,9 +825,11 @@ def get_client_config( # Arguments for API requests chat_url = f"{args.url}/v1/chat/completions" + model_name = args.served_model_name if args.served_model_name else args.model + req_args = RequestArgs( chat_url=chat_url, - model=args.model, + model=model_name, stream=not args.no_stream, limit_min_tokens=args.limit_min_tokens, limit_max_tokens=args.limit_max_tokens, @@ -1247,9 +1249,19 @@ async def main() -> None: default=0, help="Seed for random number generators (default: 0)", ) + parser.add_argument( "-m", "--model", type=str, required=True, help="Path of the LLM model" ) + parser.add_argument( + "--served-model-name", + type=str, + default=None, + help="The model name used in the API. 
" + "If not specified, the model name will be the " + "same as the ``--model`` argument. ", + ) + parser.add_argument( "-u", "--url", diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index e0da46e2accaa..cc38cd41a5b24 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -182,17 +182,17 @@ endif() # # Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms) # Flag to enable ACL kernels for AARCH64 platforms -if ( VLLM_BUILD_ACL STREQUAL "ON") +if (VLLM_BUILD_ACL STREQUAL "ON") set(USE_ACL ON) else() set(USE_ACL OFF) endif() -if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) +if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) FetchContent_Declare( oneDNN GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git - GIT_TAG v3.8.1 + GIT_TAG v3.9 GIT_PROGRESS TRUE GIT_SHALLOW TRUE ) @@ -204,7 +204,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) endif() set(ONEDNN_AARCH64_USE_ACL "ON") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/") - endif() + endif() set(ONEDNN_LIBRARY_TYPE "STATIC") set(ONEDNN_BUILD_DOC "OFF") @@ -217,38 +217,23 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) set(ONEDNN_ENABLE_ITT_TASKS "OFF") set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") + set(ONEDNN_VERBOSE "OFF") set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) FetchContent_MakeAvailable(oneDNN) - - list(APPEND LIBS dnnl) -elseif(POWER10_FOUND) - FetchContent_Declare( - oneDNN - GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git - GIT_TAG v3.7.2 - GIT_PROGRESS TRUE - GIT_SHALLOW TRUE + add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp") + target_include_directories( + dnnl_ext + PUBLIC ${oneDNN_SOURCE_DIR}/include + PUBLIC ${oneDNN_BINARY_DIR}/include + PRIVATE ${oneDNN_SOURCE_DIR}/src ) - - set(ONEDNN_LIBRARY_TYPE "STATIC") - set(ONEDNN_BUILD_DOC "OFF") - set(ONEDNN_BUILD_EXAMPLES "OFF") - set(ONEDNN_BUILD_TESTS "OFF") - set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") - set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") - set(ONEDNN_BUILD_GRAPH "OFF") - set(ONEDNN_ENABLE_JIT_PROFILING "OFF") - set(ONEDNN_ENABLE_ITT_TASKS "OFF") - set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") - set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") - set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) - - set(DNNL_CPU_RUNTIME "OMP") - - FetchContent_MakeAvailable(oneDNN) - - list(APPEND LIBS dnnl) + target_link_libraries(dnnl_ext dnnl) + target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC) + list(APPEND LIBS dnnl_ext) + set(USE_ONEDNN ON) +else() + set(USE_ONEDNN OFF) endif() message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") @@ -275,7 +260,6 @@ set(VLLM_EXT_SRC if (AVX512_FOUND AND NOT AVX512_DISABLED) set(VLLM_EXT_SRC - "csrc/cpu/quant.cpp" "csrc/cpu/shm.cpp" ${VLLM_EXT_SRC}) if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI) @@ -289,14 +273,11 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) ${VLLM_EXT_SRC}) add_compile_definitions(-DCPU_CAPABILITY_AVX512) endif() -elseif(POWER10_FOUND) - set(VLLM_EXT_SRC - "csrc/cpu/quant.cpp" - ${VLLM_EXT_SRC}) endif() -if (ASIMD_FOUND) + +if(USE_ONEDNN) set(VLLM_EXT_SRC - "csrc/cpu/quant.cpp" + "csrc/cpu/dnnl_kernels.cpp" ${VLLM_EXT_SRC}) endif() diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index ee6768bce26ca..02224cfe3ee81 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -19,7 +19,7 @@ else() FetchContent_Declare( flashmla 
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git - GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1 + GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" @@ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) set(FlashMLA_SOURCES ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp - ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu - ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu) + ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu) set(FlashMLA_INCLUDES ${flashmla_SOURCE_DIR}/csrc/cutlass/include - ${flashmla_SOURCE_DIR}/csrc/include) + ${flashmla_SOURCE_DIR}/csrc) set_gencode_flags_for_srcs( SRCS "${FlashMLA_SOURCES}" diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu index e0e95d06290df..6dd6f269f3dc9 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -167,7 +167,7 @@ typename T::Fmha::Arguments args_from_options( // TODO(trevor-m): Change split_kv back to -1 when // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will // perform worse with larger context length and smaller batch sizes. - num_kv_splits, // split_kv + static_cast(num_kv_splits), // split_kv nullptr, // is_var_split_kv }; // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute @@ -264,7 +264,7 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba // Assumes device 0 when getting sm_count. arguments.hw_info.sm_count = sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count; - arguments.split_kv = num_kv_splits; + arguments.split_kv = static_cast(num_kv_splits); MlaSm100Type::Fmha::set_split_kv(arguments); return MlaSm100Type::Fmha::get_workspace_size(arguments); diff --git a/csrc/cache.h b/csrc/cache.h index 0970b704be3ab..fb0c353b96137 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -40,9 +40,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); -void gather_cache( +void gather_and_maybe_dequant_cache( torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] 
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] torch::Tensor const& cu_seq_lens, // [BATCH+1] - int64_t batch_size, std::optional seq_starts = std::nullopt); \ No newline at end of file + int64_t batch_size, const std::string& kv_cache_dtype, + torch::Tensor const& scale, + std::optional seq_starts = std::nullopt); \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 131dcb15cd7e9..b3a985c2d5bbb 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -624,9 +624,9 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, namespace vllm { // grid is launched with dimensions (batch, num_splits) -template -__global__ void gather_cache( - const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, +template +__global__ void gather_and_maybe_dequant_cache( + const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, // ENTRIES...] scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...] const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] @@ -634,6 +634,7 @@ __global__ void gather_cache( const int32_t block_size, const int32_t entry_size, const int64_t block_table_stride, const int64_t cache_block_stride, const int64_t cache_entry_stride, const int64_t dst_entry_stride, + const float* __restrict__ scale, const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per // batch @@ -675,10 +676,16 @@ __global__ void gather_cache( if (partial_block_size) full_blocks_end -= 1; } - auto copy_entry = [&](const scalar_t* __restrict__ _src, + auto copy_entry = [&](const cache_t* __restrict__ _src, scalar_t* __restrict__ _dst) { - for (int i = threadIdx.x; i < entry_size; i += blockDim.x) - _dst[i] = _src[i]; + for (int i = threadIdx.x; i < entry_size; i += blockDim.x) { + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + _dst[i] = static_cast(_src[i]); + } else { + _dst[i] = + fp8::scaled_convert(_src[i], *scale); + } + } }; for (int pid = split_start; pid < full_blocks_end; ++pid) { @@ -705,25 +712,31 @@ __global__ void gather_cache( } // namespace vllm // Macro to dispatch the kernel based on the data type. -#define CALL_GATHER_CACHE(CPY_DTYPE) \ - vllm::gather_cache<<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst.data_ptr()), \ - block_table.data_ptr(), cu_seq_lens.data_ptr(), \ - block_size, entry_size, block_table_stride, cache_block_stride, \ - cache_entry_stride, dst_entry_stride, seq_starts_ptr); +// SCALAR_T is the data type of the destination tensor. +// CACHE_T is the stored data type of kv-cache. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \ + vllm::gather_and_maybe_dequant_cache \ + <<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + block_size, entry_size, block_table_stride, cache_block_stride, \ + cache_entry_stride, dst_entry_stride, \ + reinterpret_cast(scale.data_ptr()), seq_starts_ptr); // Gather sequences from the cache into the destination tensor. // - cu_seq_lens contains the cumulative sequence lengths for each batch // - block_table contains the cache block indices for each sequence // - Optionally, seq_starts (if provided) offsets the starting block index by // (seq_starts[bid] / page_size) -void gather_cache( +void gather_and_maybe_dequant_cache( torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] 
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] torch::Tensor const& cu_seq_lens, // [BATCH+1] - int64_t batch_size, + int64_t batch_size, const std::string& kv_cache_dtype, + torch::Tensor const& scale, std::optional seq_starts = std::nullopt) { at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -761,20 +774,8 @@ void gather_cache( dim3 grid(batch_size, num_splits); dim3 block(1024); - TORCH_CHECK(src_cache.dtype() == dst.dtype(), - "src_cache and dst must have the same dtype"); - - const int dtype_bits = src_cache.element_size() * 8; const int32_t* seq_starts_ptr = seq_starts.has_value() ? seq_starts.value().data_ptr() : nullptr; - if (dtype_bits == 32) { - CALL_GATHER_CACHE(uint32_t); - } else if (dtype_bits == 16) { - CALL_GATHER_CACHE(uint16_t); - } else if (dtype_bits == 8) { - CALL_GATHER_CACHE(uint8_t); - } else { - TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); - } + DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE); } diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index 3952c43cbc727..982f7c07a13bd 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -89,7 +89,7 @@ struct FP16Vec16 : public Vec { explicit FP16Vec16(const FP32Vec16&); - void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } + void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); } void save(void* ptr, const int elem_num) const { constexpr uint32_t M = 0xFFFFFFFF; @@ -126,7 +126,7 @@ struct BF16Vec16 : public Vec { explicit BF16Vec16(const FP32Vec16&); - void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } + void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); } void save(void* ptr, const int elem_num) const { constexpr uint32_t M = 0xFFFFFFFF; @@ -180,8 +180,8 @@ struct BF16Vec32 : public Vec { (__m128i)vec8_data.reg, 1)) {} void save(void* ptr) const { - *reinterpret_cast<__m256i*>(ptr) = reg_low; - *reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high; + _mm256_storeu_si256((__m256i*)ptr, reg_low); + _mm256_storeu_si256((__m256i*)ptr + 1, reg_high); } }; #endif diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp new file mode 100644 index 0000000000000..f3f00edb36068 --- /dev/null +++ b/csrc/cpu/dnnl_helper.cpp @@ -0,0 +1,346 @@ +#include +#include + +#include "common/memory_desc.hpp" +#include "common/memory.hpp" + +#include "dnnl_helper.h" + +static dnnl::engine& default_engine() { + static dnnl::engine engine(dnnl::engine::kind::cpu, 0); + return engine; +} + +static dnnl::stream& default_stream() { + static dnnl::stream stream(default_engine()); + return stream; +} + +void release_dnnl_matmul_handler(int64_t handler) { + DNNLMatMulPrimitiveHandler* ptr = + reinterpret_cast(handler); + delete ptr; +} + +template +class DNNLPrimitiveCache { + public: + using cache_value_t = std::pair; + using result_value_t = VT; + using container_t = std::list; + using value_iterator_t = typename container_t::iterator; + using map_t = std::unordered_map; + using creator_t = VT (*)(); + + public: + DNNLPrimitiveCache(size_t capacity) + : capacity_(capacity), + values_(), + key_to_value_(std::min(256lu, capacity)) { + assert(capacity > 0); + } + + template + result_value_t get_or_create(const KT& key, F&& creator) { + std::optional value = get_value(key); + if (value.has_value()) { + return value.value()->second; + } else { + return add_value({key, creator()})->second; + } 
+ } + + size_t size() const { return values_.size(); } + + private: + void dump_data() { + std::stringstream ss; + ss << "table_id: " << std::hex << reinterpret_cast(this) << std::dec + << "\n"; + ss << "container: ["; + for (auto&& iter : values_) { + ss << "(" << iter.first << ", " << std::hex + << reinterpret_cast(iter.second.get()) << "), " << std::dec; + } + ss << "]\n"; + + ss << "map: ["; + for (auto&& iter : key_to_value_) { + ss << "(" << iter.first << ", " << iter.second->first << ", " << std::hex + << reinterpret_cast(iter.second->second.get()) << std::dec + << "), "; + } + ss << "]\n"; + std::printf("%s\n", ss.str().c_str()); + } + + value_iterator_t add_value(cache_value_t&& new_value) { + if (size() == capacity_) { + cache_value_t& last_item = values_.back(); + key_to_value_.erase(last_item.first); + values_.pop_back(); + } + + auto& added_value_ = values_.emplace_front(std::move(new_value)); + key_to_value_.emplace(added_value_.first, values_.begin()); + return values_.begin(); + } + + std::optional get_value(const KT& key) { + if (key_to_value_.size() > 0 && key == values_.begin()->first) { + return values_.begin(); + } + + auto value_map_iterator = key_to_value_.find(key); + if (value_map_iterator != key_to_value_.end()) { + values_.splice(values_.begin(), values_, value_map_iterator->second); + return value_map_iterator->second; + } else { + return {}; + } + } + + private: + const size_t capacity_; + container_t values_; + map_t key_to_value_; +}; + +DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler( + const Args& args, dnnl::memory::data_type b_type) + : b_n_size_(args.b_n_size), + b_n_stride_(args.b_n_stride), + b_k_size_(args.b_k_size), + b_k_stride_(args.b_k_stride), + b_type_(b_type), + c_type_(args.c_type), + runtime_memory_ptrs_(8), + primitive_cache_size_(args.primitive_cache_size) { + assert(primitive_cache_size_ > 0); +} + +void DNNLMatMulPrimitiveHandler::prepack_weight( + void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) { + dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, + {b_k_stride_, b_n_stride_}); + dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr); + dnnl::memory packed_weight(b_target_mem_desc, default_engine()); + { + dnnl::reorder(original_weight, packed_weight) + .execute(default_stream(), original_weight, packed_weight); + default_stream().wait(); + } + memory_cache_[DNNL_ARG_WEIGHTS] = packed_weight; + b_target_mem_desc_ = b_target_mem_desc; +} + +void DNNLMatMulPrimitiveHandler::set_runtime_memory_ptr( + size_t index, dnnl_memory* memory_ptr) { + dnnl::impl::memory_storage_t* mem_storage_ptr = memory_ptr->memory_storage(); + dnnl_memory_desc* mem_desc = const_cast(memory_ptr->md()); + runtime_memory_ptrs_[index] = {mem_storage_ptr, mem_desc}; +} + +std::pair +DNNLMatMulPrimitiveHandler::get_runtime_memory_ptr(size_t index) { + return runtime_memory_ptrs_[index]; +} + +namespace std { +template <> +struct hash { + size_t operator()( + const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const { + return hash()(val.b_n_size) ^ hash()(val.b_k_size) ^ + hash()(static_cast(val.a_qs)) ^ + hash()(static_cast(val.b_qs)) ^ hash()(val.use_azp) ^ + hash()(static_cast(val.c_type)); + } +}; + +template <> +struct hash { + size_t operator()( + const W8A8MatMulPrimitiveHandler::MSizeCacheKey& val) const { + return hash()(val.a_m_size) ^ hash()(val.use_bias) ^ + hash()(static_cast(val.bias_type)); + } +}; +} // namespace std + +bool operator==(const 
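`DNNLPrimitiveCache` above is a small LRU map: a hit is spliced to the front of the list, and once the cache reaches capacity the least recently used entry at the back is evicted before a new primitive is created and inserted. A rough Python equivalent of `get_or_create`, using `OrderedDict` in place of the list + hash-map pair; names here are illustrative, not part of the C++ interface:

```python
from collections import OrderedDict
from typing import Callable


class PrimitiveCache:
    def __init__(self, capacity: int):
        assert capacity > 0
        self.capacity = capacity
        self.entries = OrderedDict()

    def get_or_create(self, key, creator: Callable):
        if key in self.entries:
            self.entries.move_to_end(key, last=False)  # splice hit to the front (most recently used)
            return self.entries[key]
        if len(self.entries) == self.capacity:
            self.entries.popitem(last=True)            # evict the least recently used entry
        value = creator()
        self.entries[key] = value
        self.entries.move_to_end(key, last=False)
        return value


cache = PrimitiveCache(2)
cache.get_or_create("a", lambda: 1)
cache.get_or_create("b", lambda: 2)
cache.get_or_create("c", lambda: 3)  # evicts "a"
assert "a" not in cache.entries
```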
W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l, + const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& r) { + return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size && + l.a_qs == r.a_qs && l.b_qs == r.b_qs && l.use_azp == r.use_azp && + l.c_type == r.c_type; +} + +bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l, + const W8A8MatMulPrimitiveHandler::MSizeCacheKey& r) { + return l.use_bias == r.use_bias && l.a_m_size == r.a_m_size && + l.bias_type == r.bias_type; +} + +static std::shared_ptr +get_w8a8_class_primitive_cache( + const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key, + int64_t cache_size) { + static W8A8MatMulPrimitiveHandler::ClassMatmulCache cache(128); + assert(cache_size > 0); + return cache.get_or_create(key, [&]() { + return std::make_shared(cache_size); + }); +} + +W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args) + : DNNLMatMulPrimitiveHandler( + static_cast(args), + dnnl::memory::data_type::s8), + use_azp_(args.use_a_zero_point), + a_qs_(args.a_quantization_strategy), + b_qs_(args.b_quantization_strategy), + m_size_cache_(nullptr) { + assert(a_qs_ != QuantizationStrategy::PER_OUTPUT_CHANNEL); + assert(b_qs_ != QuantizationStrategy::PER_TOKEN); + if (a_qs_ == QuantizationStrategy::PER_TOKEN) { + assert(!use_azp_); + }; + prepack_weight(args.b_ptr, + create_primitive_desc( + MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL, + .use_bias = false, + .bias_type = dnnl::memory::data_type::undef}, + true) + .weights_desc()); + init_runtime_memory_cache(args); +} + +void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) { + auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0); + auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1); + a_storage->set_data_handle((void*)args.a_ptr); + a_mem_desc->dims[0] = args.a_m_size; + c_storage->set_data_handle((void*)args.c_ptr); + c_mem_desc->dims[0] = args.a_m_size; + + if (a_qs_ == QuantizationStrategy::PER_TENSOR) { + auto&& [a_scale_storage, a_scale_mem_desc] = get_runtime_memory_ptr(2); + a_scale_storage->set_data_handle((void*)args.a_scales_ptr); + } + if (use_azp_) { + auto&& [a_zero_point_storage, a_zero_point_mem_desc] = + get_runtime_memory_ptr(3); + a_zero_point_storage->set_data_handle((void*)args.a_zero_points_ptr); + } + + if (args.use_bias) { + auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(4); + bias_storage->set_data_handle((void*)args.bias_ptr); + } + + dnnl::matmul matmul = get_matmul_cache(args); + matmul.execute(default_stream(), memory_cache_); + default_stream().wait(); +} + +dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache( + const MSizeCacheKey& key) { + if (m_size_cache_.get() == nullptr) { + ClassMatmulCacheKey key = {.b_n_size = b_n_size_, + .b_k_size = b_k_size_, + .a_qs = a_qs_, + .b_qs = b_qs_, + .use_azp = use_azp_, + .c_type = c_type_}; + m_size_cache_ = get_w8a8_class_primitive_cache(key, primitive_cache_size_); + } + + return m_size_cache_->get_or_create(key, [&]() { + dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); + return dnnl::matmul(desc); + }); +} + +void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) { + memory_cache_[DNNL_ARG_SRC] = dnnl::memory({{1, b_k_size_}, + dnnl::memory::data_type::s8, + dnnl::memory::format_tag::ab}, + default_engine(), nullptr); + set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get()); + memory_cache_[DNNL_ARG_DST] = + dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab}, + default_engine(), nullptr); + 
set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get()); + + // For PER_TOKEN, scales will be applied in outside epilogue + if (a_qs_ == QuantizationStrategy::PER_TENSOR) { + memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = dnnl::memory( + {{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr); + set_runtime_memory_ptr( + 2, memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC].get()); + if (use_azp_) { + memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = dnnl::memory( + {{1}, dnnl::memory::data_type::s32, {1}}, default_engine(), nullptr); + set_runtime_memory_ptr( + 3, memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC].get()); + } + } + + if (b_qs_ == QuantizationStrategy::PER_TENSOR) { + memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = + dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), + (void*)args.b_scales_ptr); + } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) { + memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = + dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, + default_engine(), (void*)args.b_scales_ptr); + } + + memory_cache_[DNNL_ARG_BIAS] = + dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, + default_engine(), nullptr); + set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get()); +} + +dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc( + const MSizeCacheKey& key, bool first_time) { + dnnl::memory::desc a_md({key.a_m_size, b_k_size_}, + dnnl::memory::data_type::s8, + dnnl::memory::format_tag::ab); + dnnl::memory::desc b_md; + if (first_time) { + b_md = + dnnl::memory::desc({b_k_size_, b_n_size_}, dnnl::memory::data_type::s8, + dnnl::memory::format_tag::any); + } else { + b_md = b_target_mem_desc_; + } + dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_, + dnnl::memory::format_tag::ab); + + dnnl::primitive_attr attr; + // For PER_TOKEN, scales will be applied in outside epilogue + if (a_qs_ == QuantizationStrategy::PER_TENSOR) { + attr.set_scales_mask(DNNL_ARG_SRC, 0); + if (use_azp_) { + attr.set_zero_points_mask(DNNL_ARG_SRC, 0); + } + } + + if (b_qs_ == QuantizationStrategy::PER_TENSOR) { + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); + } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) { + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2); + } + + if (key.use_bias) { + // For PER_TOKEN, bias will be applied in epilogue + assert(a_qs_ == QuantizationStrategy::PER_TENSOR); + dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1}); + return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md, + c_md, attr); + } else { + return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md, + attr); + } +} diff --git a/csrc/cpu/dnnl_helper.h b/csrc/cpu/dnnl_helper.h new file mode 100644 index 0000000000000..54ceefced9e98 --- /dev/null +++ b/csrc/cpu/dnnl_helper.h @@ -0,0 +1,169 @@ +#ifndef DNNL_HELPER_H +#define DNNL_HELPER_H + +#include +#include + +#include "oneapi/dnnl/dnnl.hpp" + +namespace c10 { +struct BFloat16; +struct Half; +} // namespace c10 + +namespace dnnl { +namespace impl { +struct memory_storage_t; +struct matmul_pd_t; +struct matmul_desc_t; +} // namespace impl +} // namespace dnnl +struct dnnl_memory_desc; + +template +class DNNLPrimitiveCache; + +template +struct DNNLType { + static constexpr dnnl::memory::data_type type = + dnnl::memory::data_type::undef; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8; +}; 
+ +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16; +}; + +template +constexpr inline dnnl::memory::data_type get_dnnl_type() { + return DNNLType>::type; +} + +class DNNLMatMulPrimitiveHandler { + public: + virtual ~DNNLMatMulPrimitiveHandler() = default; + + protected: + struct Args { + dnnl_dim_t b_n_size; + dnnl_dim_t b_n_stride; + dnnl_dim_t b_k_size; + dnnl_dim_t b_k_stride; + void* b_ptr; + dnnl::memory::data_type c_type; + size_t primitive_cache_size; + }; + + protected: + DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type); + + void prepack_weight(void* original_b_ptr, + dnnl::memory::desc b_target_mem_desc); + + void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr); + + std::pair + get_runtime_memory_ptr(size_t index); + + protected: + const dnnl_dim_t b_n_size_; + const dnnl_dim_t b_n_stride_; + const dnnl_dim_t b_k_size_; + const dnnl_dim_t b_k_stride_; + dnnl::memory::data_type b_type_; + dnnl::memory::data_type c_type_; + std::unordered_map memory_cache_; + std::vector> + runtime_memory_ptrs_; + dnnl::memory::desc b_target_mem_desc_; + int64_t primitive_cache_size_; +}; + +class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler { + public: + enum class QuantizationStrategy { PER_TOKEN, PER_TENSOR, PER_OUTPUT_CHANNEL }; + + struct Args : public DNNLMatMulPrimitiveHandler::Args { + bool use_a_zero_point; + QuantizationStrategy a_quantization_strategy; + QuantizationStrategy b_quantization_strategy; + float* b_scales_ptr; + }; + + struct ClassMatmulCacheKey { + dnnl_dim_t b_n_size; + dnnl_dim_t b_k_size; + QuantizationStrategy a_qs; + QuantizationStrategy b_qs; + bool use_azp; + dnnl::memory::data_type c_type; + + friend bool operator==(const ClassMatmulCacheKey& l, + const ClassMatmulCacheKey& r); + }; + + struct MSizeCacheKey { + dnnl_dim_t a_m_size; + bool use_bias; + dnnl::memory::data_type bias_type; + + friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r); + }; + + using MSizeCache = DNNLPrimitiveCache; + using ClassMatmulCache = + DNNLPrimitiveCache>; + + struct ExecArgs : public MSizeCacheKey { + const int8_t* a_ptr; + const float* a_scales_ptr; + const int32_t* a_zero_points_ptr; + const void* bias_ptr; + void* c_ptr; + }; + + public: + W8A8MatMulPrimitiveHandler(const Args& args); + + QuantizationStrategy get_input_scale_strategy() const { return a_qs_; } + + bool get_input_use_zero_point() const { return use_azp_; } + + void execute(ExecArgs& args); + + private: + dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key, + bool first_time); + + void init_runtime_memory_cache(const Args& args); + + dnnl::matmul get_matmul_cache(const MSizeCacheKey& key); + + private: + const bool use_azp_; + const QuantizationStrategy a_qs_; + const QuantizationStrategy b_qs_; + std::shared_ptr m_size_cache_; +}; + +#endif diff --git a/csrc/cpu/dnnl_helper.hpp b/csrc/cpu/dnnl_helper.hpp deleted file mode 100644 index 1cb8dc5b25a66..0000000000000 --- a/csrc/cpu/dnnl_helper.hpp +++ /dev/null @@ -1,206 +0,0 @@ -#ifndef DNNL_HELPER_HPP -#define DNNL_HELPER_HPP - -#include -#include - -#include 
"oneapi/dnnl/dnnl.hpp" - -namespace { -template -struct DNNLType { - static constexpr dnnl::memory::data_type type = - dnnl::memory::data_type::undef; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16; -}; - -template -constexpr inline dnnl::memory::data_type get_dnnl_type() { - return DNNLType>::type; -} -}; // namespace - -template -class DNNLPrimitiveHelper { - public: - // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias) - // A: [M, K], row-major - // B: [K, N], column-major - // C: [M, N], row-major - // bias: [N], row-major, optional - // a_scales: [MS] - // b_scales: [NS] - // Note: Due to the limitation of oneDNN - // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is - // not supported. - - template - static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c, - const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N, - dnnl_dim_t K, const float* a_scales, - const float* b_scales, dnnl_dim_t MS, - dnnl_dim_t NS) { - auto&& OutputType = get_dnnl_type(); - auto&& BiasType = get_dnnl_type(); - - dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1}); - dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K}); - dnnl::memory::desc c_md({M, N}, OutputType, {N, 1}); - - dnnl::primitive_attr attr; - if constexpr (!InputNoScale) { - if (MS == 1) { - // per-tensor - attr.set_scales_mask(DNNL_ARG_SRC, 0); - } else { - // per-token - TORCH_CHECK(false, "per-token quantization is unsupported."); - } - } - - if (NS == 1) { - // per-tensor - attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); - } else { - // per-channel - attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2); - } - - dnnl::matmul::primitive_desc matmul_pd; -// Create memory descriptors with format_tag::any for the primitive. This -// enables the matmul primitive to choose memory layouts for an -// optimized primitive implementation, and these layouts may differ from the -// ones provided by the user. 
-#ifdef __aarch64__ - auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8, - dnnl::memory::format_tag::any); - auto mat_weights_md = dnnl::memory::desc( - {K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any); - auto mat_dst_md = - dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any); - if (bias) { - dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); - matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md, - mat_weights_md, bias_md, - mat_dst_md, attr); - } else { - matmul_pd = dnnl::matmul::primitive_desc( - default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr); - } -#else - if (bias) { - dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); - matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, - bias_md, c_md, attr); - } else { - matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, - c_md, attr); - } -#endif - dnnl::matmul matmul(matmul_pd); - - auto& engine = default_engine(); - - dnnl::memory a_m(a_md, engine, (void*)a); - dnnl::memory b_m(b_md, engine, (void*)b); - dnnl::memory c_m(c_md, engine, (void*)c); - dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine, - (void*)a_scales); - dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine, - (void*)b_scales); - - auto& stream = default_stream(); - - auto mat_src_mem = a_m; - auto mat_weights_mem = b_m; - auto mat_dst_mem = c_m; -#ifdef __aarch64__ - if (matmul_pd.weights_desc() != b_m.get_desc()) { - mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine); - dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem); - } -#endif - if constexpr (InputNoScale) { - if (bias) { - dnnl::memory::desc bias_md({N}, BiasType, {1}); - dnnl::memory bias_m(bias_md, engine, (void*)bias); - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_BIAS, bias_m}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } else { - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } - } else { - if (bias) { - dnnl::memory::desc bias_md({N}, BiasType, {1}); - dnnl::memory bias_m(bias_md, engine, (void*)bias); - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_BIAS, bias_m}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } else { - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } - } - stream.wait(); - } - - private: - static dnnl::engine& default_engine() { - static dnnl::engine engine(dnnl::engine::kind::cpu, 0); - return engine; - } - - static dnnl::stream& default_stream() { - static dnnl::stream stream(default_engine()); - return stream; - } -}; -#endif diff --git a/csrc/cpu/dnnl_kernels.cpp b/csrc/cpu/dnnl_kernels.cpp new file mode 100644 index 0000000000000..acc3b9ecde143 --- /dev/null +++ b/csrc/cpu/dnnl_kernels.cpp @@ -0,0 +1,494 @@ +#include "cpu_types.hpp" +#include "dnnl_helper.h" + +namespace { +template +struct KernelVecType { + using load_vec_type = void; + using cvt_vec_type 
= void; +}; + +template <> +struct KernelVecType { + using load_vec_type = vec_op::FP32Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; + +#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT) +template <> +struct KernelVecType { + using load_vec_type = vec_op::BF16Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; +#endif + +template <> +struct KernelVecType { +#if defined(__powerpc64__) || defined(__s390x__) + // Power architecture-specific vector type + using load_vec_type = vec_op::FP32Vec16; +#else + // Fallback for other architectures + using load_vec_type = vec_op::FP16Vec16; +#endif + using cvt_vec_type = vec_op::FP32Vec16; +}; + +template +void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + const float* scale, const int32_t* azp, + const int64_t num_tokens, + const int64_t input_stride, + const int64_t hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int64_t vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t inv_scale(1.0 / *scale); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + + cvt_vec_t zp_vec; + if constexpr (AZP) { + zp_vec = cvt_vec_t(static_cast(*azp)); + } + +#pragma omp parallel for + for (int64_t i = 0; i < num_tokens; ++i) { + int64_t j = 0; + const scalar_t* input_ptr = input + i * input_stride; + int8_t* output_ptr = output + i * hidden_size; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j); + } + + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j, hidden_size - j); + } +} + +template +void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + float* scale, int32_t* azp, + const int64_t num_tokens, + const int64_t input_stride, + const int64_t hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + +#pragma omp parallel for + for (int64_t i = 0; i < num_tokens; ++i) { + cvt_vec_t max_value(std::numeric_limits::lowest()); + cvt_vec_t min_value(std::numeric_limits::max()); + { + int64_t j = 0; + const scalar_t* input_ptr = input + i * input_stride; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } + + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + + if (j + vec_elem_num == 
hidden_size) { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } else { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32, hidden_size - j); + min_value = min_value.min(elems_fp32, hidden_size - j); + } else { + max_value = max_value.max(elems_fp32.abs(), hidden_size - j); + } + } + } + + float scale_val, azp_val; + if constexpr (AZP) { + float max_scalar = max_value.reduce_max(); + float min_scalar = min_value.reduce_min(); + scale_val = (max_scalar - min_scalar) / 255.0f; + azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); + azp[i] = azp_val; + scale[i] = scale_val; + } else { + scale_val = max_value.reduce_max() / 127.0f; + scale[i] = scale_val; + } + + const cvt_vec_t inv_scale(1.0 / scale_val); + const cvt_vec_t azp_vec(azp_val); + + { + int64_t j = 0; + const scalar_t* input_ptr = input + i * input_stride; + int8_t* output_ptr = output + i * hidden_size; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j); + } + + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j, hidden_size - j); + } + } +} + +template +void dynamic_quant_epilogue(const float* input, scalar_t* output, + const float* a_scale, const int32_t* azp, + const float* azp_adj, const scalar_t* bias, + const int64_t num_tokens, + const int64_t hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + const int64_t thread_num = omp_get_max_threads(); + if (num_tokens > thread_num) { +#pragma omp parallel for + for (int64_t i = 0; i < num_tokens; ++i) { + const float* input_ptr = input + i * hidden_size; + scalar_t* output_ptr = output + i * hidden_size; + int64_t j = 0; + cvt_vec_t token_scale_vec(a_scale[i]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = a_scale[i] * static_cast(azp[i]); + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + for (; j < hidden_size - vec_elem_num; ++j) { + cvt_vec_t elems_fp32(input_ptr + j); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + j); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + elems_out.save(output_ptr + j); + } + cvt_vec_t elems_fp32(input_ptr + j); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + j); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + 
elems_out.save(output_ptr + j, hidden_size - j); + } + } else { + const int64_t vec_iteration = + (hidden_size + vec_elem_num - 1) / vec_elem_num; + const int64_t vec_iteration_per_thread = + (vec_iteration + thread_num - 1) / thread_num; + const int64_t elem_num_per_thread = vec_iteration_per_thread * vec_elem_num; +#pragma omp parallel for schedule(static, 1) + for (int64_t i = 0; i < thread_num; ++i) { + const int64_t start = elem_num_per_thread * i; + const int64_t end = std::min(hidden_size, elem_num_per_thread + start); + for (int64_t j = 0; j < num_tokens; ++j) { + cvt_vec_t token_scale_vec(a_scale[j]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = a_scale[j] * static_cast(azp[j]); + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + int64_t k = start; + const float* input_ptr = input + j * hidden_size; + scalar_t* output_ptr = output + j * hidden_size; + for (; k < end - vec_elem_num; k += vec_elem_num) { + cvt_vec_t elems_fp32(input_ptr + k); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + k); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + k); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + elems_out.save(output_ptr + k); + } + if (k < end) { + cvt_vec_t elems_fp32(input_ptr + k); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + k); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + k); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + elems_out.save(output_ptr + k, end - k); + } + } + } + } +} +} // namespace + +int64_t create_onednn_scaled_mm_handler( + const torch::Tensor& b, // [IC, OC], column-major + const torch::Tensor& b_scales, // [1] or [OC] + at::ScalarType output_type, bool dynamic_act_quant, bool use_azp, + int64_t primitive_cache_size) { + TORCH_CHECK(b.dim() == 2); + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(b_scales.is_contiguous()); + + W8A8MatMulPrimitiveHandler::Args args; + args.primitive_cache_size = primitive_cache_size; + + if (b_scales.numel() == 1) { + args.b_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR; + } else { + TORCH_CHECK_EQ(b_scales.numel(), b.size(1)); + args.b_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_OUTPUT_CHANNEL; + } + args.b_scales_ptr = b_scales.data_ptr(); + args.b_k_size = b.size(0); + args.b_k_stride = b.stride(0); + args.b_n_size = b.size(1); + args.b_n_stride = b.stride(1); + args.b_ptr = b.data_ptr(); + + if (dynamic_act_quant) { + // dynamic per-token, bias, A scales and A zps will be applied in outside. 
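The comment above marks the split made for dynamic per-token activation quantization: the oneDNN matmul only produces C_inter = s_b * (A @ B) into an fp32 buffer, and the per-token scale s_a, the zero-point correction, and the bias are applied afterwards by dynamic_quant_epilogue (defined earlier in this file). The following scalar reference is a sketch of what that vectorized kernel computes, not part of the patch; note that azp_adj[j] is the precomputed per-column zero-point adjustment with the weight scale already folded in, since this kernel never multiplies by b_scale.

    #include <cstdint>

    // c[i][j] = a_scale[i] * c_fp32[i][j]
    //         - (a_scale[i] * azp[i]) * azp_adj[j]   // only if azp/azp_adj given
    //         + bias[j]                              // only if bias given
    template <typename scalar_t>
    void dynamic_quant_epilogue_ref(const float* c_fp32, scalar_t* c,
                                    const float* a_scale, const int32_t* azp,
                                    const float* azp_adj, const scalar_t* bias,
                                    int64_t num_tokens, int64_t hidden_size) {
      for (int64_t i = 0; i < num_tokens; ++i) {
        const float zp_scale =
            azp ? a_scale[i] * static_cast<float>(azp[i]) : 0.f;
        for (int64_t j = 0; j < hidden_size; ++j) {
          float v = c_fp32[i * hidden_size + j] * a_scale[i];
          if (azp) v -= azp_adj[j] * zp_scale;
          if (bias) v += static_cast<float>(bias[j]);
          c[i * hidden_size + j] = static_cast<scalar_t>(v);
        }
      }
    }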
+ args.a_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN; + args.use_a_zero_point = false; + } else { + // static per-tensor + args.a_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR; + args.use_a_zero_point = use_azp; + } + + VLLM_DISPATCH_FLOATING_TYPES(output_type, "create_onednn_scaled_mm_handler", + [&] { + if (dynamic_act_quant) { + args.c_type = get_dnnl_type(); + } else { + args.c_type = get_dnnl_type(); + } + }); + + return reinterpret_cast(new W8A8MatMulPrimitiveHandler(args)); +} + +void onednn_scaled_mm( + torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const torch::Tensor& a_scales, // [M] or [1] + const std::optional& azp, // [M] or [1] + const std::optional& azp_adj, // [M] or [1] + const std::optional& bias, // [N] + int64_t handler) { + CPU_KERNEL_GUARD_IN(onednn_scaled_mm) + TORCH_CHECK(a.dim() == 2); + TORCH_CHECK(a.is_contiguous()); + TORCH_CHECK(c.is_contiguous()); + W8A8MatMulPrimitiveHandler* ptr = + reinterpret_cast(handler); + const int32_t* azp_ptr = nullptr; + if (azp.has_value()) { + azp_ptr = azp->data_ptr(); + } + if (ptr->get_input_scale_strategy() == + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) { + TORCH_CHECK_EQ(a_scales.numel(), 1); + } + + W8A8MatMulPrimitiveHandler::ExecArgs exec_args; + exec_args.a_ptr = a.data_ptr(); + exec_args.a_m_size = a.size(0); + exec_args.bias_ptr = nullptr; + exec_args.use_bias = false; + exec_args.a_scales_ptr = nullptr; + exec_args.a_zero_points_ptr = nullptr; + + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "onednn_scaled_mm", [&] { + if (ptr->get_input_scale_strategy() == + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) { + if (bias.has_value()) { + exec_args.bias_ptr = bias->data_ptr(); + exec_args.bias_type = get_dnnl_type(); + exec_args.use_bias = true; + } + exec_args.a_scales_ptr = a_scales.data_ptr(); + exec_args.a_zero_points_ptr = azp_ptr; + exec_args.c_ptr = c.data_ptr(); + ptr->execute(exec_args); + } else if (ptr->get_input_scale_strategy() == + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN) { + torch::Tensor tmp_fp32_out = + torch::empty_like(c, ::at::ScalarType::Float); + exec_args.c_ptr = tmp_fp32_out.data_ptr(); + ptr->execute(exec_args); + if (bias.has_value()) { + if (azp.has_value()) { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, azp_adj->data_ptr(), + bias->data_ptr(), c.size(0), c.size(1)); + } else { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, nullptr, + bias->data_ptr(), c.size(0), c.size(1)); + } + } else { + if (azp.has_value()) { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, azp_adj->data_ptr(), + (scalar_t*)nullptr, c.size(0), c.size(1)); + } else { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, nullptr, (scalar_t*)nullptr, + c.size(0), c.size(1)); + } + } + } else { + TORCH_CHECK(false, "invalid act quant type."); + } + }); +} + +// static-per-tensor quantization. 
+void static_scaled_int8_quant( + torch::Tensor& out, // [batch, hidden_size] + const torch::Tensor& input, // [batch, hidden_size] + const torch::Tensor& scale, std::optional const& azp) { + CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK_EQ(input.dim(), 2); + TORCH_CHECK_EQ(input.stride(1), 1); + TORCH_CHECK(scale.numel() == 1); + TORCH_CHECK(!azp.has_value() || azp->numel() == 1); + + const int64_t stride = input.stride(0); + const int64_t hidden_size = input.size(1); + const int64_t num_tokens = input.size(0); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "static_scaled_int8_quant_impl", [&] { + if (azp.has_value()) { + static_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + stride, hidden_size); + } else { + static_scaled_int8_quant_impl(input.data_ptr(), + out.data_ptr(), + scale.data_ptr(), nullptr, + num_tokens, stride, hidden_size); + } + }); +} + +// dynamic-per-token quantization. +void dynamic_scaled_int8_quant( + torch::Tensor& out, // [batch, hidden_size] + const torch::Tensor& input, // [batch, hidden_size] + torch::Tensor& scale, // [batch, 1] + std::optional const& azp) { + CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK_EQ(input.dim(), 2); + TORCH_CHECK_EQ(input.stride(1), 1); + + const int64_t hidden_size = input.size(1); + const int64_t num_tokens = input.size(0); + const int64_t stride = input.stride(0); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { + if (azp.has_value()) { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + stride, hidden_size); + } else { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), nullptr, num_tokens, stride, + hidden_size); + } + }); +} diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp deleted file mode 100644 index 6e120b8d20a7e..0000000000000 --- a/csrc/cpu/quant.cpp +++ /dev/null @@ -1,951 +0,0 @@ -#include "cpu_types.hpp" -#include "dnnl_helper.hpp" - -namespace { -template -struct KernelVecType { - using load_vec_type = void; - using azp_adj_load_vec_type = void; - using cvt_vec_type = void; -}; - -template <> -struct KernelVecType { - using load_vec_type = vec_op::FP32Vec16; - using azp_adj_load_vec_type = vec_op::INT32Vec16; - using cvt_vec_type = vec_op::FP32Vec16; -}; - -#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT) -template <> -struct KernelVecType { - using load_vec_type = vec_op::BF16Vec16; - using azp_adj_load_vec_type = vec_op::INT32Vec16; - using cvt_vec_type = vec_op::FP32Vec16; -}; -#endif - -template <> -struct KernelVecType { -#if defined(__powerpc64__) || defined(__s390x__) - // Power architecture-specific vector type - using load_vec_type = vec_op::FP32Vec16; -#else - // Fallback for other architectures - using load_vec_type = vec_op::FP16Vec16; -#endif - using azp_adj_load_vec_type = vec_op::INT32Vec16; - using cvt_vec_type = vec_op::FP32Vec16; -}; - -#if defined(__AVX512F__) || defined(__aarch64__) -template -void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - 
static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - const cvt_vec_t inv_scale(1.0 / *scale); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - cvt_vec_t zp_vec; - if constexpr (AZP) { - zp_vec = cvt_vec_t(static_cast(*azp)); - } - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } -} - -template -void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t max_value(std::numeric_limits::lowest()); - cvt_vec_t min_value(std::numeric_limits::max()); - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - - if (j + vec_elem_num == hidden_size) { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } else { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32, hidden_size - j); - min_value = min_value.min(elems_fp32, hidden_size - j); - } else { - max_value = max_value.max(elems_fp32.abs(), hidden_size - j); - } - } - } - - float scale_val, azp_val; - if constexpr (AZP) { - float max_scalar = max_value.reduce_max(); - float min_scalar = min_value.reduce_min(); - scale_val = (max_scalar - min_scalar) / 255.0f; - azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); - azp[i] = static_cast(azp_val); - scale[i] = scale_val; - } else { - scale_val = max_value.reduce_max() / 127.0f; - scale[i] = scale_val; - } - - const cvt_vec_t inv_scale(1.0 / scale_val); - const cvt_vec_t azp_vec(azp_val); - - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = 
elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } - } -} - -template -void static_quant_epilogue(const float* input, scalar_t* output, - const float a_scale, const float* b_scale, - const int32_t* azp_with_adj, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t a_scale_vec(a_scale); - cvt_vec_t b_scale_vec(*b_scale); - cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; - - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} - -template -void dynamic_quant_epilogue(const float* input, scalar_t* output, - const float* a_scale, const float* b_scale, - const int32_t* azp, const int32_t* azp_adj, - const scalar_t* bias, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - cvt_vec_t token_scale_vec(a_scale[i]); - cvt_vec_t token_zp_scale_vec; - if constexpr (AZP) { - float zp_scale_val = a_scale[i] * static_cast(azp[i]); - if constexpr (!PerChannel) { - zp_scale_val *= *b_scale; - } - token_zp_scale_vec = cvt_vec_t(zp_scale_val); - } - - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + 
j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} -#elif defined(__powerpc64__) -template -void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - - const cvt_vec_t inv_scale(1.0 / *scale); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - cvt_vec_t zp_vec; - if constexpr (AZP) { - zp_vec = cvt_vec_t(static_cast(*azp)); - } - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } -} -template -void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t max_value(std::numeric_limits::lowest()); - cvt_vec_t min_value(std::numeric_limits::max()); - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } - - load_vec_t elems(input + i * hidden_size + j); - 
cvt_vec_t elems_fp32(elems); - - if (j + vec_elem_num == hidden_size) { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } else { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32, hidden_size - j); - min_value = min_value.min(elems_fp32, hidden_size - j); - } else { - max_value = max_value.max(elems_fp32.abs(), hidden_size - j); - } - } - } - - float scale_val, azp_val; - if constexpr (AZP) { - float max_scalar = max_value.reduce_max(); - float min_scalar = min_value.reduce_min(); - scale_val = (max_scalar - min_scalar) / 255.0f; - azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); - azp[i] = static_cast(azp_val); - scale[i] = scale_val; - } else { - scale_val = max_value.reduce_max() / 127.0f; - scale[i] = scale_val; - } - - const cvt_vec_t inv_scale(1.0 / scale_val); - const cvt_vec_t azp_vec(azp_val); - - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } - } -} -template -void static_quant_epilogue(const float* input, scalar_t* output, - const float a_scale, const float* b_scale, - const int32_t* azp_with_adj, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t a_scale_vec(a_scale); - cvt_vec_t b_scale_vec(*b_scale); - cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; - - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} -template -void dynamic_quant_epilogue(const float* input, scalar_t* output, - const float* a_scale, const float* b_scale, - const int32_t* azp, const int32_t* azp_adj, - const scalar_t* bias, const int 
num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - cvt_vec_t token_scale_vec(a_scale[i]); - cvt_vec_t token_zp_scale_vec; - if constexpr (AZP) { - float zp_scale_val = a_scale[i] * static_cast(azp[i]); - if constexpr (!PerChannel) { - zp_scale_val *= *b_scale; - } - token_zp_scale_vec = cvt_vec_t(zp_scale_val); - } - - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} -#else -template -void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int32_t* azp, - const int num_tokens, - const int hidden_size) { - TORCH_CHECK(false, - "static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 " - "support.") -} - -template -void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, int32_t* azp, - const int num_tokens, - const int hidden_size) { - TORCH_CHECK(false, - "dynamic_scaled_int8_quant_impl requires " - "AVX512/powerpc64/AArch64 support.") -} - -template -void static_quant_epilogue(const float* input, scalar_t* output, - const float a_scale, const float* b_scale, - const int32_t* azp_with_adj, const int num_tokens, - const int hidden_size) { - TORCH_CHECK( - false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.") -} - -template -void dynamic_quant_epilogue(const float* input, scalar_t* output, - const float* a_scale, const float* b_scale, - const int32_t* azp, const int32_t* azp_with_adj, - const scalar_t* bias, const int num_tokens, - const int hidden_size) { - TORCH_CHECK( - false, - "dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.") -} -#endif -} // namespace - -void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major - const torch::Tensor& a, // [M, IC], row-major - const torch::Tensor& b, // [IC, OC], column-major - const torch::Tensor& a_scales, // [1] or [M] - const 
torch::Tensor& b_scales, // [1] or [OC] - const std::optional& bias // [OC] -) { - CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) - // Checks for conformality - TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, - "int8_scaled_mm only supports INT8 inputs.") - TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); - TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); - TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); - TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); - - // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && - b.stride(1) % 16 == 0); // 16 Byte Alignment - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - - if (bias) { - TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && - bias->dim() == 1); - } - - VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] { - if (a_scales.numel() != 1) { - // per-token - // Note: oneDNN doesn't support per-token activation quantization - // Ideally we want to fuse the GEMM and the scale procedure with oneDNN - // JIT, the intermediate data is cached in registers or L1. But for now - // the oneDNN GEMM code generation only supports two quantization - // patterns: per-tensor or per-output-channel of weight. - // So we have to apply the per-token scale with a 'epilogue'. In C=s_a * - // s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN - // GEMM, then the per-token scale (and bias) is applied with the epilogue - // C=s_a * C_inter + bias. - torch::Tensor tmp_fp32_out = - torch::empty_like(c, ::at::ScalarType::Float); - // Compute C_inter=s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); - if (bias.has_value()) { - // Compute C=s_a * C_inter + bias - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, - bias->data_ptr(), c.size(0), c.size(1)); - } else { - // Compute C=s_a * C_inter - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, - c.size(0), c.size(1)); - } - } else { - // per-tensor - if (bias.has_value()) { - // Compute C=s_a * s_b * (A@B) + bias - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), c.data_ptr(), - bias->data_ptr(), a.size(0), b.size(1), a.size(1), - a_scales.data_ptr(), b_scales.data_ptr(), - a_scales.numel(), b_scales.numel()); - } else { - // Compute C=s_a * s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), c.data_ptr(), - nullptr, a.size(0), b.size(1), a.size(1), - a_scales.data_ptr(), b_scales.data_ptr(), - a_scales.numel(), b_scales.numel()); - } - } - }); -} - -void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major - const torch::Tensor& a, // [M, IC], row-major - const torch::Tensor& b, // [IC, OC], column-major - const torch::Tensor& a_scales, // [1] or [M] - const torch::Tensor& b_scales, // [1] or [OC] - const torch::Tensor& azp_adj, // [OC] - const std::optional& azp, // [1] or [M] - const std::optional& bias // [OC] -) { - CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp) - // Checks for conformality - TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, - "int8_scaled_mm_azp 
only supports INT8 inputs.") - TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); - TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); - TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); - TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); - - // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && - b.stride(1) % 16 == 0); // 16 Byte Alignment - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - - if (bias) { - TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous()); - } - if (azp) { - TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous()); - } - TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous()); - - // azp & bias types - TORCH_CHECK(azp_adj.dtype() == torch::kInt32); - TORCH_CHECK(!azp || azp->dtype() == torch::kInt32); - TORCH_CHECK(!bias || bias->dtype() == c.dtype(), - "currently bias dtype must match output dtype ", c.dtype()); - - VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] { - torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); - if (a_scales.numel() != 1) { - // per-token - // Note: oneDNN doesn't support per-token activation quantization - // Compute C_inter=s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); - if (bias.has_value()) { - // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias - if (b_scales.numel() != 1) { - // Per-Channel - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), - bias->data_ptr(), c.size(0), c.size(1)); - } else { - // Per-Tensor - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), - bias->data_ptr(), c.size(0), c.size(1)); - } - } else { - // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj - if (b_scales.numel() != 1) { - // Per-Channel - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), nullptr, - c.size(0), c.size(1)); - } else { - // Per-Tensor - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), nullptr, - c.size(0), c.size(1)); - } - } - } else { - // per-tensor - if (bias.has_value()) { - // Compute C_inter=s_a * s_b * (A@B) + bias - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), bias->data_ptr(), - a.size(0), b.size(1), a.size(1), a_scales.data_ptr(), - b_scales.data_ptr(), a_scales.numel(), b_scales.numel()); - } else { - // Compute C_inter=s_a * s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), a_scales.data_ptr(), b_scales.data_ptr(), - a_scales.numel(), b_scales.numel()); - } - - // Compute C=C_inter - s_a * s_b * azp_adj - if (b_scales.numel() != 1) { - // Per-Channel - static_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - *a_scales.data_ptr(), b_scales.data_ptr(), - azp_adj.data_ptr(), a.size(0), b.size(1)); - } else { - // 
Per-Tensor - static_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - *a_scales.data_ptr(), b_scales.data_ptr(), - azp_adj.data_ptr(), a.size(0), b.size(1)); - } - } - }); -} - -// static-per-tensor quantization. -void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] - const torch::Tensor& input, // [..., hidden_size] - const torch::Tensor& scale, - std::optional const& azp) { - CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(scale.numel() == 1); - TORCH_CHECK(!azp.has_value() || azp->numel() == 1); - - const int hidden_size = input.size(-1); - const int num_tokens = input.numel() / hidden_size; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "static_scaled_int8_quant_impl", [&] { - if (azp.has_value()) { - static_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), azp->data_ptr(), num_tokens, - hidden_size); - } else { - static_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), nullptr, num_tokens, hidden_size); - } - }); -} - -// dynamic-per-token quantization. -void dynamic_scaled_int8_quant( - torch::Tensor& out, // [..., hidden_size] - const torch::Tensor& input, // [..., hidden_size] - torch::Tensor& scale, // [..., 1] - std::optional const& azp) { - CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - - int const hidden_size = input.size(-1); - int const num_tokens = input.numel() / hidden_size; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { - if (azp.has_value()) { - dynamic_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), azp->data_ptr(), num_tokens, - hidden_size); - } else { - dynamic_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), nullptr, num_tokens, hidden_size); - } - }); -} - -#if defined(__powerpc64__) -void int8_scaled_mm_ppc64le(torch::Tensor& c, // [M, OC], row-major - const torch::Tensor& a, // [M, IC], row-major - const torch::Tensor& b, // [IC, OC], column-major - const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const std::optional& bias // [OC] -) { - CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) - // Checks for conformality - TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, - "int8_scaled_mm_ppc64le only supports INT8 inputs."); - TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); - TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); - // We dont need this - TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); - TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); - - // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && - b.stride(1) % 16 == 0); // 16 Byte Alignment - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - - if (bias) { - TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && - bias->dim() == 1); - } - VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] { - torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); - // Compute C_inter=s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), nullptr, 
b_scales.data_ptr(), 0, b_scales.numel()); - if (bias.has_value()) { - // Compute C=s_a * C_inter + bias - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, - bias->data_ptr(), c.size(0), c.size(1)); - } else { - // Compute C=s_a * C_inter - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, - c.size(0), c.size(1)); - } - }); -} - -#endif diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index b20a054648428..c9f426bdf618a 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -6,25 +6,20 @@ std::string init_cpu_threads_env(const std::string& cpu_ids); -void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a, - const torch::Tensor& b, const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const std::optional& bias); +void release_dnnl_matmul_handler(int64_t handler); -void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a, - const torch::Tensor& b, const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const torch::Tensor& azp_adj, - const std::optional& azp, - const std::optional& bias); +int64_t create_onednn_scaled_mm_handler(const torch::Tensor& b, + const torch::Tensor& b_scales, + at::ScalarType output_type, + bool dynamic_act_quant, bool use_azp, + int64_t primitive_cache_size); -#if defined(__powerpc64__) -void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a, - const torch::Tensor& b, - const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const std::optional& bias); -#endif +void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a, + const torch::Tensor& a_scales, + const std::optional& azp, + const std::optional& azp_adj, + const std::optional& bias, + int64_t handler); void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query, torch::Tensor& kv_cache, double scale, @@ -151,8 +146,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); // Quantization -#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) +#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \ + defined(__powerpc64__) at::Tag stride_tag = at::Tag::needs_fixed_stride_order; + // Helper function to release oneDNN handlers + ops.def("release_dnnl_matmul_handler(int handler) -> ()", + &release_dnnl_matmul_handler); + + // Create oneDNN W8A8 handler + ops.def( + "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType " + "output_type, bool dynamic_act_quant, bool use_azp, int " + "primitive_cache_size) -> int", + &create_onednn_scaled_mm_handler); + + // oneDNN scaled_mm for W8A8 with static per-tensor activation quantization + ops.def( + "onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, " + "Tensor? azp_adj, Tensor? bias, int handler) -> ()"); + ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm); // Compute int8 quantized tensor for given scaling factor. ops.def( @@ -168,50 +180,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("dynamic_scaled_int8_quant", torch::kCPU, &dynamic_scaled_int8_quant); - // W8A8 GEMM, supporting symmetric per-tensor or per-row/column - // quantization. - ops.def( - "cutlass_scaled_mm(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor? 
bias) -> ()", - {stride_tag}); - ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); - // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column - // quantization. - ops.def( - "cutlass_scaled_mm_azp(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor azp_adj," - " Tensor? azp, Tensor? bias) -> ()", - {stride_tag}); - ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); -#elif defined(__powerpc64__) - // Compute int8 quantized tensor for given scaling factor. - ops.def( - "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," - "Tensor? azp) -> ()"); - ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); - - // Compute int8 quantized tensor and scaling factor - ops.def( - "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " - "Tensor!? azp) -> ()"); - ops.impl("dynamic_scaled_int8_quant", torch::kCPU, - &dynamic_scaled_int8_quant); - // W8A8 GEMM, supporting symmetric quantization. - ops.def( - "cutlass_scaled_mm(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor? bias) -> ()"); - ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le); - // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column - // quantization. - ops.def( - "cutlass_scaled_mm_azp(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor azp_adj," - " Tensor? azp, Tensor? bias) -> ()"); - ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); #endif // SHM CCL diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index 2922352a3f7cc..ca0c873f49d9f 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -45,8 +45,6 @@ void moe_permute( auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess auto permuted_experts_id = torch::empty_like(topk_ids); auto sorted_row_idx = torch::empty_like(inv_permuted_idx); - auto align_expert_first_token_offset = - torch::zeros_like(expert_first_token_offset); CubKeyValueSorter sorter{}; int64_t* valid_num_ptr = nullptr; @@ -85,12 +83,14 @@ void moe_permute( }); // get m_indices and update expert_first_token_offset with align block - getMIndices(get_ptr(expert_first_token_offset), - get_ptr(align_expert_first_token_offset), - get_ptr(m_indices), n_local_expert, align_block_size_value, - stream); + // this is only required for DeepGemm and not required for CUTLASS group gemm if (align_block_size.has_value()) { - // update align_expert_first_token_offset + auto align_expert_first_token_offset = + torch::zeros_like(expert_first_token_offset); + getMIndices(get_ptr(expert_first_token_offset), + get_ptr(align_expert_first_token_offset), + get_ptr(m_indices), n_local_expert, align_block_size_value, + stream); expert_first_token_offset.copy_(align_expert_first_token_offset); } } @@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, torch::Tensor& expert_first_token_offset, torch::Tensor& src_row_id2dst_row_id_map, torch::Tensor& m_indices) { - TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); + TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0"); } -void moe_unpermute(const torch::Tensor& input, - const torch::Tensor& topk_weights, torch::Tensor& topk_ids, - const torch::Tensor& token_expert_indices, - const std::optional& expert_map, - int64_t n_expert, int64_t n_local_expert, int64_t topk, - const std::optional& 
align_block_size, - torch::Tensor& permuted_input, - torch::Tensor& expert_first_token_offset, - torch::Tensor& src_row_id2dst_row_id_map, - torch::Tensor& m_indices) { +void moe_unpermute( + const torch::Tensor& permuted_hidden_states, + const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx, + const std::optional& expert_first_token_offset, int64_t topk, + torch::Tensor& hidden_states) { TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); } @@ -224,4 +219,4 @@ bool moe_permute_unpermute_supported() { TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("moe_permute", &moe_permute); m.impl("moe_unpermute", &moe_unpermute); -} +} \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index 64bcec6ca1527..86fe848e2fd5a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -229,6 +229,11 @@ void get_cutlass_moe_mm_data( const int64_t num_experts, const int64_t n, const int64_t k, const std::optional& blockscale_offsets); +void get_cutlass_moe_mm_problem_sizes( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets); + void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, diff --git a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh index 6c6e89790847f..15bb2c300543c 100644 --- a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh @@ -10,7 +10,7 @@ template __global__ void get_group_gemm_starts( - int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, + int64_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets, ElementAccumulator** a_scales_offsets, ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int, ElementAB* b_base_as_int, ElementC* out_base_as_int, @@ -34,7 +34,7 @@ __global__ void get_group_gemm_starts( else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ get_group_gemm_starts \ <<<1, num_experts, 0, stream>>>( \ - static_cast(expert_offsets.data_ptr()), \ + static_cast(expert_offsets.data_ptr()), \ static_cast(a_ptrs.data_ptr()), \ static_cast(b_ptrs.data_ptr()), \ static_cast(out_ptrs.data_ptr()), \ @@ -61,6 +61,8 @@ void run_get_group_gemm_starts( TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + // expect int64_t to avoid overflow during offset calculations + TORCH_CHECK(expert_offsets.dtype() == torch::kInt64); int num_experts = static_cast(expert_offsets.size(0)); bool per_act_token = a_scales.numel() != 1; diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index 100f485084444..49cafcc32adc6 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids, } } +namespace { +inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + torch::Tensor& atomic_buffer, + int64_t num_experts, int64_t n, + int64_t k, cudaStream_t stream, + const bool swap_ab) { + int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); + + const int32_t* topk_ptr = 
static_cast(topk_ids.data_ptr()); + int32_t* ps1_ptr = static_cast(problem_sizes1.data_ptr()); + int32_t* ps2_ptr = static_cast(problem_sizes2.data_ptr()); + int32_t* atomic_ptr = static_cast(atomic_buffer.data_ptr()); + + if (swap_ab) { + compute_problem_sizes<<>>( + topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr, + static_cast(topk_ids.numel()), static_cast(n), + static_cast(k)); + } else { + compute_problem_sizes<<>>( + topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr, + static_cast(topk_ids.numel()), static_cast(n), + static_cast(k)); + } +} +} // namespace + +void get_cutlass_moe_mm_problem_sizes_caller( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets) { + auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); + auto options_int32 = + torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); + torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); + + // Swap-AB should be disabled for FP4 path + bool may_swap_ab = (!blockscale_offsets.has_value()) && + (topk_ids.numel() <= SWAP_AB_THRESHOLD); + + launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, + atomic_buffer, num_experts, n, k, stream, + may_swap_ab); +} + void get_cutlass_moe_mm_data_caller( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller( bool may_swap_ab = (!blockscale_offsets.has_value()) && (topk_ids.numel() <= SWAP_AB_THRESHOLD); - if (may_swap_ab) { - compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, - k); - } else { - compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, - k); - } + launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, + atomic_buffer, num_experts, n, k, stream, + may_swap_ab); if (blockscale_offsets.has_value()) { // fp4 path diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 106bacb4883cb..84843ee6e0949 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -76,6 +76,11 @@ void get_cutlass_moe_mm_data_caller( const int64_t num_experts, const int64_t n, const int64_t k, const std::optional& blockscale_offsets); +void get_cutlass_moe_mm_problem_sizes_caller( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets); + void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -293,6 +298,25 @@ void get_cutlass_moe_mm_data( version_num, ". 
Required capability: 90 or 100"); } +void get_cutlass_moe_mm_problem_sizes( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets) { + int32_t version_num = get_sm_version_num(); +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) + get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1, + problem_sizes2, num_experts, n, k, + blockscale_offsets); + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm " + "kernel for CUDA device capability: ", + version_num, ". Required capability: 90 or 100"); +} + void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 9af7833d09f32..0d14ba15937c6 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -349,9 +349,12 @@ def to_cute_constant(value: list[int]): def unique_schedules(impl_configs: list[ImplConfig]): - return list( - set(sch for impl_config in impl_configs - for sch in impl_config.schedules)) + # Use dict over set for deterministic ordering + return list({ + sch: None + for impl_config in impl_configs + for sch in impl_config.schedules + }.keys()) def unsigned_type_with_bitwidth(num_bits): @@ -568,78 +571,79 @@ def generate(): itertools.repeat(default_heuristic)) ] - # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) - # TODO (LucasWilkinson): Further tuning required - qqq_tile_heuristic_config = { - #### M = 257+ - # ((128, 256), (2, 1, 1)) Broken for QQQ types - # TODO (LucasWilkinson): Investigate further - # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), - # "M > 256": ((128, 256), (2, 1, 1)), - "M > 256": ((128, 128), (2, 1, 1)), - #### M = 129-256 - "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), - "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), - # ((128, 256), (2, 1, 1)) Broken for QQQ types - # TODO (LucasWilkinson): Investigate further - # "M > 128": ((128, 256), (2, 1, 1)), - "M > 128": ((128, 128), (2, 1, 1)), - #### M = 65-128 - "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), - "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), - "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), - "M > 64": ((128, 128), (2, 1, 1)), - #### M = 33-64 - "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), - # Broken for QQQ types - # TODO (LucasWilkinson): Investigate further - #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), - "M > 32": ((128, 64), (2, 1, 1)), - #### M = 17-32 - "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), - "M > 16": ((256, 32), (2, 1, 1)), - #### M = 1-16 - "N >= 26624": ((256, 16), (1, 1, 1)), - None: ((128, 16), (1, 1, 1)), - } + # TODO: Support W4A8 when ready + # # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) + # # TODO (LucasWilkinson): Further tuning required + # qqq_tile_heuristic_config = { + # #### M = 257+ + # # ((128, 256), (2, 1, 1)) Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), + # # "M > 256": ((128, 256), (2, 1, 1)), + # "M > 256": ((128, 128), (2, 1, 1)), + # #### M = 
129-256 + # "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), + # "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), + # # ((128, 256), (2, 1, 1)) Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # # "M > 128": ((128, 256), (2, 1, 1)), + # "M > 128": ((128, 128), (2, 1, 1)), + # #### M = 65-128 + # "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), + # "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), + # "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), + # "M > 64": ((128, 128), (2, 1, 1)), + # #### M = 33-64 + # "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), + # # Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), + # "M > 32": ((128, 64), (2, 1, 1)), + # #### M = 17-32 + # "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), + # "M > 16": ((256, 32), (2, 1, 1)), + # #### M = 1-16 + # "N >= 26624": ((256, 16), (1, 1, 1)), + # None: ((128, 16), (1, 1, 1)), + # } - # For now we use the same heuristic for all types - # Heuristic is currently tuned for H100s - qqq_heuristic = [ - (cond, ScheduleConfig(*tile_config, - **sch_common_params)) # type: ignore - for cond, tile_config in qqq_tile_heuristic_config.items() - ] + # # For now we use the same heuristic for all types + # # Heuristic is currently tuned for H100s + # qqq_heuristic = [ + # (cond, ScheduleConfig(*tile_config, + # **sch_common_params)) # type: ignore + # for cond, tile_config in qqq_tile_heuristic_config.items() + # ] - QQQ_kernel_types = [ - *(TypeConfig( - a=DataType.s8, - b=VLLMDataType.u4b8, - b_group_scale=b_group_scale, - b_group_zeropoint=DataType.void, - b_channel_scale=DataType.f32, - a_token_scale=DataType.f32, - out=DataType.f16, - accumulator=DataType.s32, - ) for b_group_scale in (DataType.f16, DataType.void)), - *(TypeConfig( - a=DataType.e4m3, - b=VLLMDataType.u4b8, - b_group_scale=b_group_scale, - b_group_zeropoint=DataType.void, - b_channel_scale=DataType.f32, - a_token_scale=DataType.f32, - out=DataType.f16, - accumulator=DataType.f32, - ) for b_group_scale in (DataType.f16, DataType.void)), - ] + # QQQ_kernel_types = [ + # *(TypeConfig( + # a=DataType.s8, + # b=VLLMDataType.u4b8, + # b_group_scale=b_group_scale, + # b_group_zeropoint=DataType.void, + # b_channel_scale=DataType.f32, + # a_token_scale=DataType.f32, + # out=DataType.f16, + # accumulator=DataType.s32, + # ) for b_group_scale in (DataType.f16, DataType.void)), + # *(TypeConfig( + # a=DataType.e4m3, + # b=VLLMDataType.u4b8, + # b_group_scale=b_group_scale, + # b_group_zeropoint=DataType.void, + # b_channel_scale=DataType.f32, + # a_token_scale=DataType.f32, + # out=DataType.f16, + # accumulator=DataType.f32, + # ) for b_group_scale in (DataType.f16, DataType.void)), + # ] - impl_configs += [ - ImplConfig(x[0], x[1], x[2]) - for x in zip(QQQ_kernel_types, - itertools.repeat(get_unique_schedules(qqq_heuristic)), - itertools.repeat(qqq_heuristic)) - ] + # impl_configs += [ + # ImplConfig(x[0], x[1], x[2]) + # for x in zip(QQQ_kernel_types, + # itertools.repeat(get_unique_schedules(qqq_heuristic)), + # itertools.repeat(qqq_heuristic)) + # ] output_dir = os.path.join(SCRIPT_DIR, "generated") diff --git a/csrc/quantization/marlin/dense/LICENSE b/csrc/quantization/marlin/dense/LICENSE deleted file mode 100644 index 1d1e4cf9c8233..0000000000000 --- a/csrc/quantization/marlin/dense/LICENSE +++ /dev/null @@ -1,209 +0,0 @@ -Contains code from 
https://github.com/IST-DASLab/marlin - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. 
Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ------------------------------------------------------------------------------------- - -This product bundles various third-party components under other open source licenses. -This section summarizes those components and their licenses. See licenses/ -for text of these licenses. diff --git a/csrc/quantization/marlin/dense/common/base.h b/csrc/quantization/marlin/dense/common/base.h deleted file mode 100644 index 68c83d5478cf8..0000000000000 --- a/csrc/quantization/marlin/dense/common/base.h +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Modified by HandH1998 - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; diff --git a/csrc/quantization/marlin/dense/common/mem.h b/csrc/quantization/marlin/dense/common/mem.h deleted file mode 100644 index 64f9c393d77ce..0000000000000 --- a/csrc/quantization/marlin/dense/common/mem.h +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Modified by HandH1998 - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. 
-__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu deleted file mode 100644 index ea96326ed7e61..0000000000000 --- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu +++ /dev/null @@ -1,1073 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include - -#include -#include -#include -#include -#include - -#include - -#include "common/base.h" -#include "core/registration.h" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - #include "common/mem.h" -#endif - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin_dense { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -using I4 = Vec; -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. 
-__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. - if (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. 
- auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory - // We typically use `constexpr` to indicate that this value is a compile-time - // constant - constexpr int a_sh_stride = - 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory - constexpr int a_gl_rd_delta_o = - 16 * thread_k_blocks / - 8; // delta between subsequent A tiles in global memory - int a_gl_rd_delta_i = - a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile - constexpr int a_sh_wr_delta = - a_sh_stride * - (threads / a_gl_rd_delta_o); // between shared memory writes - constexpr int a_sh_rd_delta_o = - 2 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads - constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile - constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile - constexpr int a_sh_wr_iters = - ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile - - int b_gl_stride = 16 * prob_n / 32; - constexpr int b_sh_stride = 32 * thread_n_blocks / 4; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); - constexpr int b_sh_wr_delta = threads; - constexpr int b_sh_rd_delta = threads; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_sh_stage = s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. 
- int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = - b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x; - auto b_sh_rd = threadIdx.x; - - int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - auto s_sh_wr = threadIdx.x; - int s_sh_rd; - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - if (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s = sh_b + (stages * b_sh_stage); - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; - - // Zero accumulators. 
- auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); - B_ptr[i] += b_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if constexpr (group_blocks != -1) { - // This assumes group_blocks >= thread_k_blocks - // and would need to be modified to support smaller groups. - static_assert(group_blocks >= thread_k_blocks); - if (pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); - s_gl_rd += s_gl_rd_delta; - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticeable drop in performance. - if constexpr (group_blocks != -1) { - // This assumes group_blocks >= thread_k_blocks - // and would need to be modified to support smaller groups. - static_assert(group_blocks >= thread_k_blocks); - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; - FragB frag_b0 = dequant(b_quant); - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. 
- if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); - FragB frag_b1 = dequant(b_quant_shift); - if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; - if (red_off >= 1) { - auto red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - auto c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. 
- #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - if (group_blocks == - -1) // for per-column quantization we finally apply the scale here - res = __hmul2(res, s[0]); - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. 
- auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); - zero_accums(); - wait_for_stage(); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - }; - start_pipes(); - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) break; - } - a_gl_rd += a_gl_rd_delta_o * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (group_blocks == -1 && last) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); - cp_async_fence(); - } - thread_block_reduce(); - if (group_blocks == -1 && last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - start_pipes(); - } - } - } -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. 
-const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory -const int SHARED_MEM = - 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -static constexpr int tile_size = 16; -static constexpr int max_par = 16; - -static constexpr int pack_factor_4bit = - 8; // We have 8 4-bit vals inside a 32 bit - -#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - SHARED_MEM); \ - Marlin<<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X -}; - -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // thread_k can be only 128 or 64 (because it must be less than groupsize - // which is 128) - if (th_config.thread_k != 128 && th_config.thread_k != 64) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - return true; -} - -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - } - - return thread_config_t{-1, -1, -1}; -} - -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) - -void marlin_cuda(const void* A, const void* B, void* C, void* 
s, int prob_m, - int prob_n, int prob_k, void* workspace, int groupsize = -1, - int dev = 0, cudaStream_t stream = 0, int thread_k = -1, - int thread_n = -1, int sms = -1, int max_par = 16) { - int tot_m = prob_m; - int tot_m_blocks = ceildiv(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - - // Set thread config - thread_config_t th_config; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; - } else { - // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); - } - - if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { - throw std::runtime_error( - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + " for MKN = [" + - str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]"); - } - - // Uncomment for debug - // std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) + - // ", thread_n = " + str(th_config.thread_n) + - // ", num_threads = " + str(th_config.num_threads) + " for - // MKN = [" + str(prob_m) + - // ", " + str(prob_k) + ", " + str(prob_n) + "]\n"; - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; - int blocks = sms; - - if (prob_m == 0 || prob_n == 0 || prob_k == 0) { - return; - } - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - if (group_blocks != -1) { - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - - int* locks = (int*)workspace; - - for (int i = 0; i < tot_m_blocks; i += 4) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > 4) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) par = max_par; - prob_m = 64 * par; - i += 4 * (par - 1); - thread_m_blocks = 4; - } - - // For compilation speed, we only define the kernel configurations that have - // seemed useful (in terms of performance) in our testing, however many more - // are, in principle, possible. 
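For reference, the row-block batching arithmetic in the launcher loop above, shown in isolation as a host-only sketch before the CALL_IF dispatch below (the batch size of 200 and the printf are illustrative, not part of the deleted code):

#include <algorithm>
#include <cstdio>

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int max_par = 16;                // same cap as the deleted code
  int tot_m = 200;                       // example batch size
  int tot_m_blocks = ceildiv(tot_m, 16); // 13 blocks of 16 rows
  int pad = 16 * tot_m_blocks - tot_m;   // 8 padded rows in the last block

  for (int i = 0; i < tot_m_blocks; i += 4) {
    int thread_m_blocks = tot_m_blocks - i;
    int prob_m = tot_m - 16 * i;
    int par = 1;
    if (thread_m_blocks > 4) {
      // Run `par` batch-64 problems side by side (only valid without padding).
      par = std::min((16 * thread_m_blocks - pad) / 64, max_par);
      prob_m = 64 * par;
      i += 4 * (par - 1);
      thread_m_blocks = 4;
    }
    std::printf("launch: prob_m=%d, thread_m_blocks=%d, par=%d\n", prob_m,
                thread_m_blocks, par);
  }
  return 0;
}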
- if (false) { - } - CALL_IF(8, 8, 256) - CALL_IF(16, 4, 256) - CALL_IF(8, 4, 128) - CALL_IF(4, 8, 128) - else { - throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + - ", " + str(prob_k) + ", " + str(prob_n) + "]" + - ", groupsize = " + str(groupsize) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - } -} - -} // namespace marlin_dense - -torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t size_m, int64_t size_n, int64_t size_k) { - // Verify M - TORCH_CHECK(size_m == a.size(0), - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - - // Verify K - TORCH_CHECK(size_k == a.size(1), - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % marlin_dense::tile_size == 0, - "size_k = " + str(size_k) + " is not divisible by tile_size = " + - str(marlin_dense::tile_size)); - TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + - str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(marlin_dense::tile_size)); - - // Verify N - TORCH_CHECK(b_scales.size(1) == size_n, - "b_scales.size(1) = " + str(b_scales.size(1)) + - ", size_n = " + str(size_n)); - TORCH_CHECK( - b_q_weight.size(1) % marlin_dense::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(marlin_dense::tile_size)); - - int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) * - marlin_dense::pack_factor_4bit; - TORCH_CHECK( - size_n == actual_size_n, - "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); - - // Verify A device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - // Verify B device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - // Verify scales device and strides - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - // Alloc C matrix - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Detect groupsize - if (b_scales.size(0) != 1) { - TORCH_CHECK(size_k % b_scales.size(0) == 0, - "size_k = " + str(size_k) + - ", is not divisible by b_scales.size(0) = " + - str(b_scales.size(0))); - } - int groupsize = b_scales.size(0) == 1 ? 
-1 : size_k / b_scales.size(0); - - // Verify groupsize - TORCH_CHECK(groupsize == -1 || groupsize == 128, - "Unexpected groupsize = " + str(groupsize)); - - // Verify workspace size - TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + - str(marlin_dense::min_thread_n)); - int min_workspace_size = - (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); - - int dev = a.get_device(); - marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_n, sms, marlin_dense::max_par); - - return c; -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("marlin_gemm", &marlin_gemm); -} diff --git a/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu b/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu deleted file mode 100644 index c96d68d9b29aa..0000000000000 --- a/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu +++ /dev/null @@ -1,1248 +0,0 @@ -/* - * Adapted from - * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda_kernel.cu - * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda.cpp - * Modified by HandH1998 - * Copyright (C) 2024 HandH1998 - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include - -#include -#include -#include -#include -#include - -#include - -#include "../dense/common/base.h" -#include "core/registration.h" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - #include "../dense/common/mem.h" -#endif - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -using I4 = Vec; -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-integer-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS_GROUP = Vec; // weight per-group quantization scales -using FragS_CHANNEL = - Vec; // weight per-channel quantization scales or activaton - // per-token quantization scales - -// NOTE(HandH1998): cp.async.cg only support BYTES = 16, however, -// cp.async.ca can support BYTES = 4, 8, 16; -// as s_tok's shape is equal to prob_m, we need set s_tok to float type, -// and cp_size = 1 float, i.e., 4 BYTES -// Asynchronous global->shared copy for activation quantizaton scales s_tok -__device__ inline void cp_async1(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 4; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.ca.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// m16n8k16 tensor core mma instruction with int8 inputs and int32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - int* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.s8.s32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), - "r"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in int8 tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" - : "=r"(a[0]), "=r"(a[1]) - : "r"(smem)); -} - -inline __device__ half2 float2_to_half2(float2 f) { - uint32_t res; - // NOTE(HandH1998): h0,h1 should be uint16_t, not half - uint16_t h0, h1; - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h0) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h1) : "f"(f.y)); - asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(h0), "h"(h1)); - return reinterpret_cast(res); -} - -inline __device__ float int32_to_float(int h) { - float res; - asm volatile("cvt.rn.f32.s32 %0, %1;\n" : "=f"(res) : "r"(h)); - return res; -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values -// for weight per channel dequant. 
-__device__ inline FragB dequant_per_channel(int q) { - static constexpr int MASK = 0xf0f0f0f0; - FragB frag_b; - frag_b[0] = (q & MASK); - return frag_b; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values -// for weight per group dequant. -__device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) { - static constexpr uint32_t LO = 0x000f000f; - static constexpr uint32_t HI = 0x00f000f0; - static constexpr uint32_t EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - static constexpr uint32_t SUB = 0x64086408; - static constexpr uint32_t MUL = 0x2c002c00; - static constexpr uint32_t ADD = 0xd480d480; - *reinterpret_cast(&t0) = __hsub2( - *reinterpret_cast(&t0), *reinterpret_cast(&SUB)); - *reinterpret_cast(&t1) = __hfma2( - *reinterpret_cast(&t1), *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - - uint16_t s = reinterpret_cast(&frag_s)[i]; - uint32_t double_s; - // pack 2xfp16 to half2 - asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(double_s) : "h"(s), "h"(s)); - // dequant and convert 4 half to 4 uint8 (be placed at the low 8 bits of 4 - // half, respectively) - static constexpr uint32_t MAGIC_NUM = 0x64806480; - *reinterpret_cast(&t0) = __hfma2( - *reinterpret_cast(&t0), *reinterpret_cast(&double_s), - *reinterpret_cast(&MAGIC_NUM)); - *reinterpret_cast(&t1) = __hfma2( - *reinterpret_cast(&t1), *reinterpret_cast(&double_s), - *reinterpret_cast(&MAGIC_NUM)); - // take out the 4 uint8 from 4 half, then convert them to 4 int8 and pack 4 - // int8 into 1 uint32 - FragB frag_b; - uint32_t uint8s; - static constexpr uint32_t MASK_0246 = 0x6420; - static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" - : "=r"(uint8s) - : "r"(t0), "r"(t1), "n"(MASK_0246)); - frag_b[0] = (uint8s ^ UINT8s_TO_INT8s_MASK); - return frag_b; -} - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // int8 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // int32 global_reduce buffer of shape - // (max_par*16*4)xn, as int8 tensor core's output is - // int32 dtype - int4* __restrict__ D, // fp16 output buffer of shape mxn - const float* __restrict__ s_tok, // fp32 activation per-token quantization - // scales of shape mx1 - const int4* __restrict__ s_ch, // fp32 weight per-channel quantization - // scales of shape 1xn - const int4* __restrict__ s_group, // fp16 weight per-group quantization - // scales of shape (k/groupsize)xn, when - // group_blocks=-1, it should be nullptr - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). 
Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. - if constexpr (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 16; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 4; - D += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - s_tok += (slice_col_par / n_tiles) * 16 * thread_m_blocks; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. 
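The striped assignment described above can be reproduced on the host; this sketch prints the tile-to-threadblock mapping for the 3x3-tile, 5-SM example in the comment (the variable names are illustrative, only the ceildiv and modulo arithmetic mirrors the kernel):

#include <cstdio>

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int k_tiles = 3, n_tiles = 3, blocks = 5, parallel = 1;
  int iters = ceildiv(k_tiles * n_tiles * parallel, blocks); // 2 tiles per block

  for (int block = 0; block < blocks; block++) {
    int first = iters * block; // first linear tile index owned by this block
    for (int t = first; t < first + iters && t < k_tiles * n_tiles; t++) {
      int slice_col = t / k_tiles; // column slice
      int slice_row = t % k_tiles; // k position within the column
      std::printf("tile (row=%d, col=%d) -> block %d\n", slice_row, slice_col,
                  block);
    }
  }
  return 0;
}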
- auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 16; - C += 16 * thread_m_blocks * prob_n / 4; - D += 16 * thread_m_blocks * prob_n / 8; - s_tok += 16 * thread_m_blocks; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - int a_gl_stride = prob_k / 16; // stride of the A matrix in global memory - // We typically use `constexpr` to indicate that this value is a compile-time - // constant - constexpr int a_sh_stride = - 16 * thread_k_blocks / 16; // stride of an A matrix tile in shared memory - constexpr int a_gl_rd_delta_o = - 16 * thread_k_blocks / - 16; // delta between subsequent A tiles in global memory - int a_gl_rd_delta_i = - a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile - constexpr int a_sh_wr_delta = - a_sh_stride * - (threads / a_gl_rd_delta_o); // between shared memory writes - constexpr int a_sh_rd_delta_o = - 1 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads - constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile - constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile - constexpr int a_sh_wr_iters = - ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile - - int b_gl_stride = 16 * prob_n / 32; - constexpr int b_sh_stride = 32 * thread_n_blocks / 4; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); - constexpr int b_sh_wr_delta = threads; - constexpr int b_sh_rd_delta = threads; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - constexpr int s_tok_sh_stride = 16 * thread_m_blocks; - - constexpr int s_ch_sh_stride = 16 * thread_n_blocks / 4; - - int s_group_gl_stride = prob_n / 8; - constexpr int s_group_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_group_sh_stage = s_group_sh_stride; - int s_group_gl_rd_delta = s_group_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. 
- // NOTE(HandH1998): int8 input a only need 16 threads to load 16x16 matrix - int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16); - a_sh_rd += 1 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = - b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x; - auto b_sh_rd = threadIdx.x; - - auto s_tok_gl_rd = threadIdx.x; - // NOTE(HandH1998): activation scale s_tok need shuffle to [0, 8, 1, 9, 2, 10, - // 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] for example, 0, 8 row scales serve for - // thread 0, 1, 2, 3. For more details, refer to mma operand A layout as - // s_tok's size is not fixed, we can not shuffle before inference we shuffle - // it when fetching s_tok from global memory to shared memory, that's why - // s_tok_sh_wr is like this - int s_tok_sh_wr = - (threadIdx.x / 16) * 16 + (threadIdx.x % 8) * 2 + (threadIdx.x % 16) / 8; - int s_tok_sh_rd = (threadIdx.x % 32) / 4; - bool s_tok_sh_wr_pred = threadIdx.x < prob_m; - - auto s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x; - auto s_ch_sh_wr = threadIdx.x; - int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - 2 * ((threadIdx.x % 32) % 4); - bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride; - - int s_group_gl_rd, s_group_sh_wr, s_group_sh_rd; - bool s_group_sh_wr_pred; - if constexpr (group_blocks != -1) { - s_group_gl_rd = - s_group_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_group_sh_stride * slice_col + threadIdx.x; - s_group_sh_wr = threadIdx.x; - // NOTE(HandH1998): s_group_sh_rd is related to mma output C - s_group_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - s_group_sh_wr_pred = threadIdx.x < s_group_sh_stride; - } - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. 
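The XOR remapping referred to above is easiest to see with small indices; this host-only sketch contrasts a direct row-major layout with the XOR-permuted one (the width of 8 is illustrative, not the kernel's actual a_sh_stride):

#include <cstdio>

int main() {
  const int width = 8; // 16-byte int4 elements per row in this sketch
  for (int row = 0; row < 4; row++) {
    for (int col = 0; col < width; col++) {
      int naive = row * width + col;            // same column -> same bank group
      int swizzled = row * width + (col ^ row); // permuted per row
      std::printf("(%d,%d): naive slot %2d, swizzled slot %2d\n", row, col,
                  naive, swizzled);
    }
  }
  return 0;
}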
- int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - // NOTE(HandH1998): stages need >= 4, otherwise, sh_s_tok = sh + max(stages * - // a_sh_stage + stages * b_sh_stage, 4 * stages * a_sh_stage) - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s_tok = sh_b + (stages * b_sh_stage); - int4* sh_s_ch = sh_s_tok + s_tok_sh_stride; - int4* sh_s_group = sh_s_ch + s_ch_sh_stride; - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS_GROUP frag_s_group[2][4]; - FragS_CHANNEL frag_s_tok[thread_m_blocks]; - FragS_CHANNEL frag_s_ch[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); - B_ptr[i] += b_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if constexpr (group_blocks != -1) { - if (pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_group_stage = sh_s_group + s_group_sh_stage * pipe; - if (s_group_sh_wr_pred) - cp_async4(&sh_s_group_stage[s_group_sh_wr], - &s_group[s_group_gl_rd]); - s_group_gl_rd += s_group_gl_rd_delta; - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. 
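fetch_to_shared() and wait_for_stage() above form a stages-deep prefetch pipeline; the host-only sketch below traces its issue/consume order, with printf standing in for the real async copies and tensor-core work:

#include <cstdio>

int main() {
  const int stages = 4;      // matches STAGES in the deleted kernel
  const int total_tiles = 6; // example number of k-tiles in a slice

  int issued = 0;
  // Start the pipeline: issue the first stages-1 fetches up front.
  for (; issued < stages - 1 && issued < total_tiles; issued++)
    std::printf("fetch tile %d -> slot %d\n", issued, issued % stages);

  for (int consumed = 0; consumed < total_tiles; consumed++) {
    // Keep the next fetch in flight while computing on an older stage,
    // mirroring the `stages - 2` active fetches described above.
    if (issued < total_tiles) {
      std::printf("fetch tile %d -> slot %d\n", issued, issued % stages);
      issued++;
    }
    std::printf("compute tile %d from slot %d\n", consumed, consumed % stages);
  }
  return 0;
}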
- auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticeable drop in performance. - if constexpr (group_blocks != -1) { - int4* sh_s_group_stage = - sh_s_group + - s_group_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s_group[k % 2])[0] = - sh_s_group_stage[s_group_sh_rd]; - } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - // int b_quant_shift = b_quant << 4; - FragB frag_b0, frag_b1; - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. - if constexpr (group_blocks != -1) { - int b_quant_shift = b_quant >> 8; - frag_b0 = dequant_per_group(b_quant, frag_s_group[k % 2][j], 0); - frag_b1 = dequant_per_group(b_quant_shift, frag_s_group[k % 2][j], 1); - } else { - int b_quant_shift = b_quant << 4; - frag_b0 = dequant_per_channel(b_quant); - frag_b1 = dequant_per_channel(b_quant_shift); - } - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; - if (red_off >= 1) { - auto red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. 
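The logarithmic reduction mentioned above folds the upper half of the partial sums into the lower half at each step, matching the `i <= red_idx && red_idx < 2 * i` condition in the loop that follows; a plain C++ stand-in for the shared-memory version:

#include <cstdio>
#include <vector>

int main() {
  // One fabricated partial sum per warp group (8 groups in this sketch).
  std::vector<int> partial = {1, 2, 3, 4, 5, 6, 7, 8};
  int red_off = static_cast<int>(partial.size()) / 2;

  for (int i = red_off; i > 0; i /= 2) {
    // Only groups with index in [i, 2*i) contribute at this step.
    for (int red_idx = i; red_idx < 2 * i; red_idx++)
      partial[red_idx - i] += partial[red_idx];
  }
  std::printf("reduced sum = %d\n", partial[0]); // 36
  return 0;
}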
- - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - int* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - int* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - int* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - // global_reduce works on INT32 elements, which are the results of INT8 GEMM. - // This is why we need another INT32 maxtrix `C` to reduce instead of the - // original half matrix `D`. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 4; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 8 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2; - c_gl_wr += (4 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads * 2; - auto c_sh_wr = 2 * threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. 
- #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i + 1], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2) + 1], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 d_red1 = sh[c_sh_wr + i * c_sh_wr_delta]; - int4 d_red2 = sh[c_sh_wr + i * c_sh_wr_delta + 1]; - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - reinterpret_cast(&d_red1)[j]; - } - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)] += - reinterpret_cast(&d_red2)[j]; - } - } - if (!last) { - int4 d1, d2; - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast(&d1)[j] = reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]; - } - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast(&d2)[j] = reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)]; - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - d1; - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2) + - 1] = d2; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. 
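global_reduce() above is serialized across the threadblocks that share one column slice through a per-slice lock; this sketch imitates that ordering with std::thread and std::atomic (the partial values are fabricated, and the real barrier_acquire()/barrier_release() operate on a lock in global GPU memory):

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  const int slice_count = 3;
  std::atomic<int> lock{0};  // counts how many blocks have finished
  std::vector<int> c(4, 0);  // stand-in for the INT32 reduce buffer

  auto block = [&](int slice_idx) {
    // Fabricated per-block partials for one fragment.
    std::vector<int> frag_c(4, slice_idx + 1);
    while (lock.load(std::memory_order_acquire) != slice_idx) {
    }  // barrier_acquire
    for (size_t i = 0; i < c.size(); i++) c[i] += frag_c[i];
    lock.store(slice_idx + 1, std::memory_order_release);  // barrier_release
  };

  std::vector<std::thread> blocks;
  for (int s = 0; s < slice_count; s++) blocks.emplace_back(block, s);
  for (auto& t : blocks) t.join();

  std::printf("c[0] after serialized reduce = %d\n", c[0]);  // 1 + 2 + 3 = 6
  return 0;
}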
- auto write_result = [&]() { - int d_gl_stride = prob_n / 8; - constexpr int d_sh_stride = 2 * thread_n_blocks + 1; - int d_gl_wr_delta = d_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int d_sh_rd_delta = - d_sh_stride * (threads / (2 * thread_n_blocks)); - - int d_gl_wr = d_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - d_gl_wr += (2 * thread_n_blocks) * slice_col; - int d_sh_wr = - (4 * d_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - d_sh_wr += 32 * (threadIdx.x / 32); - int d_sh_rd = d_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int d_gl_wr_end = d_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, int c0, int c1, float a_s, FragS_CHANNEL& w_s) { - float2 deq_res; - deq_res.x = int32_to_float(c0) * w_s[0] * a_s; - deq_res.y = int32_to_float(c1) * w_s[1] * a_s; - ((half2*)sh)[idx] = float2_to_half2(deq_res); - }; - - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = d_sh_wr + 8 * j; - write(wr + (4 * d_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s_tok[i][0], - frag_s_ch[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * d_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s_tok[i][1], - frag_s_ch[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * d_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s_tok[i][0], - frag_s_ch[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * d_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s_tok[i][1], - frag_s_ch[j / 2][2 * (j % 2) + 1]); - } - d_sh_wr += 16 * (4 * d_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (d_gl_wr < d_gl_wr_end) { - D[d_gl_wr] = sh[d_sh_rd]; - d_gl_wr += d_gl_wr_delta; - d_sh_rd += d_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); - zero_accums(); - wait_for_stage(); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - }; - start_pipes(); - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) break; - } - a_gl_rd += a_gl_rd_delta_o * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. 
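In scalar form, the dequantization applied by write() above multiplies each INT32 accumulator by the per-token activation scale and the per-channel weight scale before the FP16 store; a sketch with made-up values, ignoring the half2 packing and fragment indexing:

#include <cstdint>
#include <cstdio>

int main() {
  int32_t acc = 12345;  // INT8 tensor-core accumulator for one (row, column)
  float s_tok = 0.02f;  // per-token activation scale for this row
  float s_ch = 0.001f;  // per-channel weight scale for this column

  // Same order of operations as write(): int32 -> float, then both scales.
  float d = static_cast<float>(acc) * s_ch * s_tok;
  std::printf("dequantized output = %f\n", d);  // stored as fp16 by the kernel
  return 0;
}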
- if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (last) { - if (s_tok_sh_wr_pred) { - cp_async1(&sh_s_tok[s_tok_sh_wr], &s_tok[s_tok_gl_rd]); - } - if (s_ch_sh_wr_pred) { - cp_async4(&sh_s_ch[s_ch_sh_wr], &s_ch[s_ch_gl_rd]); - } - cp_async_fence(); - } - thread_block_reduce(); - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - frag_s_tok[i][0] = - *reinterpret_cast(&sh_s_tok[16 * i + 2 * s_tok_sh_rd]); - frag_s_tok[i][1] = *reinterpret_cast( - &sh_s_tok[16 * i + 2 * s_tok_sh_rd + 1]); - } - reinterpret_cast(&frag_s_ch)[0] = sh_s_ch[s_ch_sh_rd + 0]; - reinterpret_cast(&frag_s_ch)[1] = sh_s_ch[s_ch_sh_rd + 1]; - reinterpret_cast(&frag_s_ch)[2] = sh_s_ch[s_ch_sh_rd + 8]; - reinterpret_cast(&frag_s_ch)[3] = sh_s_ch[s_ch_sh_rd + 9]; - } - } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - s_group_gl_rd = s_group_sh_stride * slice_col + threadIdx.x; - s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x; - start_pipes(); - } - } - } -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // int8 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // int32 global_reduce buffer of shape - // (max_par*16*4)xn, as int8 tensor core's output is - // int32 dtype - int4* __restrict__ D, // fp16 output buffer of shape mxn - const float* __restrict__ s_tok, // fp32 activation per-token quantization - // scales of shape mx1 - const int4* __restrict__ s_ch, // fp32 weight per-channel quantization - // scales of shape 1xn - const int4* __restrict__ s_group, // fp16 weight per-group quantization - // scales of shape (k/groupsize)xn, when - // group_blocks=-1, it should be nullptr - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. 
-const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -static constexpr int tile_size = 16; -static constexpr int max_par = 16; - -static constexpr int pack_factor_4bit = - 8; // We have 8 4-bit vals inside a 32 bit - -#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, D_ptr, s_tok_ptr, s_ch_ptr, s_group_ptr, \ - prob_m, prob_n, prob_k, locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X -}; - -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // thread_k can be only 128 or 64 (because it must be less than groupsize - // which is 128) - if (th_config.thread_k != 128 && th_config.thread_k != 64) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - return true; -} - -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - } - - return thread_config_t{-1, -1, -1}; -} - -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) - -void marlin_qqq_cuda(const void* A, const void* B, void* C, void* D, - void* s_tok, void* s_ch, void* 
s_group, int prob_m, - int prob_n, int prob_k, void* workspace, - int groupsize = -1, int dev = 0, cudaStream_t stream = 0, - int thread_k = -1, int thread_n = -1, int sms = -1, - int max_par = 16) { - int tot_m = prob_m; - int tot_m_blocks = ceildiv(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - // Set thread config - thread_config_t th_config; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; - } else { - // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); - } - - if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { - throw std::runtime_error( - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + " for MKN = [" + - str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]"); - } - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; - int blocks = sms; - - if (prob_m == 0 || prob_n == 0 || prob_k == 0) { - return; - } - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - if (group_blocks != -1) { - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - int4* D_ptr = (int4*)D; - const float* s_tok_ptr = (const float*)s_tok; - const int4* s_ch_ptr = (const int4*)s_ch; - const int4* s_group_ptr = (const int4*)s_group; - - int* locks = (int*)workspace; - - for (int i = 0; i < tot_m_blocks; i += 4) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > 4) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) par = max_par; - prob_m = 64 * par; - i += 4 * (par - 1); - thread_m_blocks = 4; - } - - // For compilation speed, we only define the kernel configurations that have - // seemed useful (in terms of performance) in our testing, however many more - // are, in principle, possible. 
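The max_shared_mem handling above relies on the opt-in shared-memory attribute; a minimal CUDA sketch of that pattern (dummy_kernel and the launch shape are placeholders, not part of the deleted code):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_kernel() { extern __shared__ int4 sh[]; (void)sh; }

int main() {
  int dev = 0, max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  std::printf("opt-in shared memory per block: %d bytes\n", max_shared_mem);

  // Without raising this attribute, launches that request more than 48 KiB of
  // dynamic shared memory fail; the __CALL_IF macro above does the same.
  cudaFuncSetAttribute(dummy_kernel,
                       cudaFuncAttributeMaxDynamicSharedMemorySize,
                       max_shared_mem);
  dummy_kernel<<<1, 256, max_shared_mem>>>();
  cudaDeviceSynchronize();
  return 0;
}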
- if (false) { - } - CALL_IF(8, 8, 256) - CALL_IF(16, 4, 256) - CALL_IF(8, 4, 128) - CALL_IF(4, 8, 128) - else { - throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + - ", " + str(prob_k) + ", " + str(prob_n) + "]" + - ", groupsize = " + str(groupsize) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 16) * par; - D_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - s_tok_ptr += 16 * thread_m_blocks * par; - } -} -} // anonymous namespace - -torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, - torch::Tensor const& b_q_weight, - torch::Tensor const& s_tok, - torch::Tensor const& s_ch, - torch::Tensor const& s_group, - torch::Tensor& workspace, int64_t size_m, - int64_t size_n, int64_t size_k) { - // Verify M - TORCH_CHECK(size_m == a.size(0), - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - TORCH_CHECK(size_m == s_tok.numel(), - "Shape mismatch: s_tok.numel() = " + str(s_tok.numel()) + - ", size_m = " + str(size_m)); - - // Verify K - TORCH_CHECK(size_k == a.size(1), - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % tile_size == 0, - "size_k = " + str(size_k) + - " is not divisible by tile_size = " + str(tile_size)); - TORCH_CHECK( - (size_k / tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + str(b_q_weight.size(0)) + - ", size_k = " + str(size_k) + ", tile_size = " + str(tile_size)); - - int groupsize = (s_group.numel() == 0) ? -1 : size_k / s_group.size(0); - // Verify groupsize - TORCH_CHECK(groupsize == -1 || groupsize == 128, - "Unexpected groupsize = " + str(groupsize)); - - // Verify N - TORCH_CHECK(s_ch.numel() == size_n, - "Shape mismatch: s_ch.numel() = " + str(s_ch.numel()) + - ", size_n = " + str(size_n)); - TORCH_CHECK(b_q_weight.size(1) % tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(tile_size)); - if (groupsize != -1) { - TORCH_CHECK(s_group.size(1) == size_n, - "Shape mismatch: s_group.size(1) = " + str(s_group.size(1)) + - ", size_n = " + str(size_n)); - TORCH_CHECK( - size_k % s_group.size(0) == 0, - "size_k = " + str(size_k) + - ", is not divisible by s_group.size(0) = " + str(s_group.size(0))); - } - - int actual_size_n = (b_q_weight.size(1) / tile_size) * pack_factor_4bit; - TORCH_CHECK(size_n == actual_size_n, - "Shape mismatch: size_n = " + str(size_n) + - ", actual_size_n = " + str(actual_size_n)); - - // Verify A device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - // Verify B device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - // Verify s_tok device, strides and dtype - TORCH_CHECK(s_tok.device().is_cuda(), "s_tok is not on GPU"); - TORCH_CHECK(s_tok.is_contiguous(), "s_tok is not contiguous"); - TORCH_CHECK(s_tok.dtype() == torch::kFloat32, "s_tok's dtype is not float32"); - - // Verify s_ch device, strides and dtype - TORCH_CHECK(s_ch.device().is_cuda(), "s_ch is not on GPU"); - TORCH_CHECK(s_ch.is_contiguous(), "s_ch is not contiguous"); - TORCH_CHECK(s_ch.dtype() == torch::kFloat32, "s_ch's dtype is not float32"); - - // Verify s_group device, strides and dtype - TORCH_CHECK(s_group.device().is_cuda(), 
"s_group is not on GPU"); - TORCH_CHECK(s_group.is_contiguous(), "s_group is not contiguous"); - TORCH_CHECK(s_group.dtype() == torch::kFloat16, - "s_group's dtype is not float16"); - - // Verify workspace size - TORCH_CHECK(size_n % min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + str(min_thread_n)); - int min_workspace_size = (size_n / min_thread_n) * max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); - - // Alloc C matrix - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options_c = torch::TensorOptions().dtype(torch::kInt).device(a.device()); - torch::Tensor c = torch::empty({max_par * 64, size_n}, options_c); - - // Alloc D matrix - auto options_d = - torch::TensorOptions().dtype(torch::kFloat16).device(a.device()); - torch::Tensor d = torch::empty({size_m, size_n}, options_d); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - int dev = a.get_device(); - marlin_qqq_cuda( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), d.data_ptr(), - s_tok.data_ptr(), s_ch.data_ptr(), s_group.data_ptr(), size_m, size_n, - size_k, workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par); - - return d; -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("marlin_qqq_gemm", &marlin_qqq_gemm); -} diff --git a/csrc/quantization/vectorization_utils.cuh b/csrc/quantization/vectorization_utils.cuh index 8aa0147df6ba8..98b491b7e23fc 100644 --- a/csrc/quantization/vectorization_utils.cuh +++ b/csrc/quantization/vectorization_utils.cuh @@ -41,8 +41,10 @@ __device__ inline void vectorize_with_alignment( for (int i = tid; i < num_vec; i += stride) { vout_t tmp; - vec_op(tmp, v_in[i]); - v_out[i] = tmp; + // Make a local copy of the entire pack + vin_t src = v_in[i]; // <- encourages a single vector ld + vec_op(tmp, src); + v_out[i] = tmp; // <- encourages a single vector st } return; } @@ -71,8 +73,10 @@ __device__ inline void vectorize_with_alignment( // 2. vectorize the main part for (int i = tid; i < num_vec; i += stride) { vout_t tmp; - vec_op(tmp, v_in[i]); - v_out[i] = tmp; + // Make a local copy of the entire pack + vin_t src = v_in[i]; // <- encourages a single vector ld + vec_op(tmp, src); + v_out[i] = tmp; // <- encourages a single vector st } // 3. handle the tail @@ -125,7 +129,8 @@ __device__ inline void vectorize_read_with_alignment(const InT* in, int len, auto* v_in = reinterpret_cast(in); for (int i = tid; i < num_vec; i += stride) { - vec_op(v_in[i]); + vin_t tmp = v_in[i]; + vec_op(tmp); } return; } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7079671c2eb16..4edb7af50f102 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -241,14 +241,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // custom types: // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA - // Marlin (Dense) Optimized Quantized GEMM for GPTQ. - ops.def( - "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " - "Tensor! 
workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> " - "Tensor", - {stride_tag}); - // conditionally compiled so impl in source file - // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. ops.def( "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, " @@ -353,15 +345,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size); #ifndef USE_ROCM - // marlin_qqq_gemm for QQQ. - ops.def( - "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, " - "Tensor s_tok, Tensor s_ch, Tensor s_group, " - "Tensor! workspace, SymInt size_m, SymInt size_n, " - "SymInt size_k) -> Tensor", - {stride_tag}); - // conditionally compiled so impl registration is in source file - // CUTLASS nvfp4 block scaled GEMM ops.def( "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b," @@ -440,6 +423,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data); + // A function that computes problem sizes for each expert's multiplication + // used by the two mms called from fused MoE operation. It takes topk_ids as + // an input, and computes problem_sizes1 and problem_sizes2 only. + ops.def( + "get_cutlass_moe_mm_problem_sizes(Tensor topk_ids, " + " Tensor! problem_sizes1, " + " Tensor! problem_sizes2, " + " int num_experts, int n, int k, " + " Tensor? blockscale_offsets) -> ()", + {stride_tag}); + ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA, + &get_cutlass_moe_mm_problem_sizes); + // A function that computes data required to run fused MoE with w8a8 grouped // GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs // as an input, and computes expert_offsets (token start indices of each @@ -676,11 +672,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "str kv_cache_dtype) -> ()"); cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); - // Gather cache blocks from src_cache to dst. + // Gather cache blocks from src_cache to dst, dequantizing from + // src_cache's dtype to dst's dtype if necessary. cache_ops.def( - "gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, " - "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); - cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache); + "gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, " + " Tensor block_table, Tensor cu_seq_lens, " + " int batch_size, " + " str kv_cache_dtype, " + " Tensor scale, Tensor? seq_starts) -> ()"); + cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA, + &gather_and_maybe_dequant_cache); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { diff --git a/docker/Dockerfile b/docker/Dockerfile index 74938917781ac..cfaa59868215c 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -372,31 +372,45 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist # Install FlashInfer from source ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" -# Keep this in sync with https://github.com/vllm-project/vllm/blob/main/requirements/cuda.txt -# We use `--force-reinstall --no-deps` to avoid issues with the existing FlashInfer wheel. -ARG FLASHINFER_GIT_REF="v0.2.11" +# Keep this in sync with "flashinfer" extra in setup.py +ARG FLASHINFER_GIT_REF="v0.2.12" +# Flag to control whether to compile FlashInfer AOT kernels +# Set to "true" to enable AOT compilation: +# docker build --build-arg FLASHINFER_AOT_COMPILE=true ... 
+ARG FLASHINFER_AOT_COMPILE=false RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' . /etc/environment git clone --depth 1 --recursive --shallow-submodules \ --branch ${FLASHINFER_GIT_REF} \ ${FLASHINFER_GIT_REPO} flashinfer - # Exclude CUDA arches for older versions (11.x and 12.0-12.7) - # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg. - if [[ "${CUDA_VERSION}" == 11.* ]]; then - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" - elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" - else - # CUDA 12.8+ supports 10.0a and 12.0 - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" - fi - echo "🏗️ Building FlashInfer for arches: ${FI_TORCH_CUDA_ARCH_LIST}" - # Needed to build AOT kernels pushd flashinfer - TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ - python3 -m flashinfer.aot - TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ - uv pip install --system --no-build-isolation --force-reinstall --no-deps . + if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then + # Exclude CUDA arches for older versions (11.x and 12.0-12.7) + # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg. + if [[ "${CUDA_VERSION}" == 11.* ]]; then + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" + elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" + else + # CUDA 12.8+ supports 10.0a and 12.0 + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" + fi + echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}" + # Build AOT kernels + TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ + python3 -m flashinfer.aot + # Install with no-build-isolation since we already built AOT kernels + TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ + uv pip install --system --no-build-isolation . \ + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') + # Download pre-compiled cubins + TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ + python3 -m flashinfer --download-cubin || echo "WARNING: Failed to download flashinfer cubins." + else + echo "🏗️ Installing FlashInfer without AOT compilation in JIT mode" + uv pip install --system . \ + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. 
-f1,2 | tr -d '.') + fi popd rm -rf flashinfer BASH diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 4f40f32a39f26..f164857325043 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -71,7 +71,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace RUN cd /vllm-workspace \ && rm -rf vllm \ && python3 -m pip install -e tests/vllm_test_utils \ - && python3 -m pip install lm-eval[api]==0.4.4 \ + && python3 -m pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] \ && python3 -m pip install pytest-shard # ----------------------- diff --git a/docker/Dockerfile.s390x b/docker/Dockerfile.s390x index 4e89bb3057c5e..9270b48c54d4b 100644 --- a/docker/Dockerfile.s390x +++ b/docker/Dockerfile.s390x @@ -16,7 +16,7 @@ ENV LANG=C.UTF-8 \ RUN microdnf install -y \ which procps findutils tar vim git gcc gcc-gfortran g++ make patch zlib-devel \ libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \ - openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy && \ + openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy libsndfile && \ microdnf clean all # Python Installation @@ -136,6 +136,71 @@ RUN --mount=type=cache,target=/root/.cache/uv \ mkdir -p /tmp/hf-xet/dist && \ cp dist/*.whl /tmp/hf-xet/dist/ +# Build numba +FROM python-install AS numba-builder + +ARG MAX_JOBS +ARG NUMBA_VERSION=0.61.2 + +WORKDIR /tmp + +# Clone all required dependencies +RUN --mount=type=cache,target=/root/.cache/uv \ + microdnf install ninja-build gcc gcc-c++ -y && \ + git clone --recursive https://github.com/llvm/llvm-project.git -b llvmorg-15.0.7 && \ + git clone --recursive https://github.com/numba/llvmlite.git -b v0.44.0 && \ + git clone --recursive https://github.com/numba/numba.git -b ${NUMBA_VERSION} && \ + cd llvm-project && mkdir build && cd build && \ + uv pip install 'cmake<4' setuptools numpy && \ + export PREFIX=/usr/local && CMAKE_ARGS="${CMAKE_ARGS} -DLLVM_ENABLE_PROJECTS=lld;libunwind;compiler-rt" \ + CFLAGS="$(echo $CFLAGS | sed 's/-fno-plt //g')" \ + CXXFLAGS="$(echo $CXXFLAGS | sed 's/-fno-plt //g')" \ + CMAKE_ARGS="${CMAKE_ARGS} -DFFI_INCLUDE_DIR=$PREFIX/include" \ + CMAKE_ARGS="${CMAKE_ARGS} -DFFI_LIBRARY_DIR=$PREFIX/lib" \ + cmake -DCMAKE_INSTALL_PREFIX="${PREFIX}" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_LIBRARY_PATH="${PREFIX}" \ + -DLLVM_ENABLE_LIBEDIT=OFF \ + -DLLVM_ENABLE_LIBXML2=OFF \ + -DLLVM_ENABLE_RTTI=ON \ + -DLLVM_ENABLE_TERMINFO=OFF \ + -DLLVM_INCLUDE_BENCHMARKS=OFF \ + -DLLVM_INCLUDE_DOCS=OFF \ + -DLLVM_INCLUDE_EXAMPLES=OFF \ + -DLLVM_INCLUDE_GO_TESTS=OFF \ + -DLLVM_INCLUDE_TESTS=OFF \ + -DLLVM_INCLUDE_UTILS=ON \ + -DLLVM_INSTALL_UTILS=ON \ + -DLLVM_UTILS_INSTALL_DIR=libexec/llvm \ + -DLLVM_BUILD_LLVM_DYLIB=OFF \ + -DLLVM_LINK_LLVM_DYLIB=OFF \ + -DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD=WebAssembly \ + -DLLVM_ENABLE_FFI=ON \ + -DLLVM_ENABLE_Z3_SOLVER=OFF \ + -DLLVM_OPTIMIZED_TABLEGEN=ON \ + -DCMAKE_POLICY_DEFAULT_CMP0111=NEW \ + -DCOMPILER_RT_BUILD_BUILTINS=ON \ + -DCOMPILER_RT_BUILTINS_HIDE_SYMBOLS=OFF \ + -DCOMPILER_RT_BUILD_LIBFUZZER=OFF \ + -DCOMPILER_RT_BUILD_CRT=OFF \ + -DCOMPILER_RT_BUILD_MEMPROF=OFF \ + -DCOMPILER_RT_BUILD_PROFILE=OFF \ + -DCOMPILER_RT_BUILD_SANITIZERS=OFF \ + -DCOMPILER_RT_BUILD_XRAY=OFF \ + -DCOMPILER_RT_BUILD_GWP_ASAN=OFF \ + -DCOMPILER_RT_BUILD_ORC=OFF \ + -DCOMPILER_RT_INCLUDE_TESTS=OFF \ + ${CMAKE_ARGS} -GNinja ../llvm \ + + && ninja install . 
&& \ + # build llvmlite + cd ../../llvmlite && python setup.py bdist_wheel && \ + cd ../numba && \ + if ! grep '#include "dynamic_annotations.h"' numba/_dispatcher.cpp; then \ + sed -i '/#include "internal\/pycore_atomic.h"/i\#include "dynamic_annotations.h"' numba/_dispatcher.cpp; \ + fi && python setup.py bdist_wheel + + # Final build stage FROM python-install AS vllm-cpu ARG PYTHON_VERSION @@ -163,23 +228,30 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \ --mount=type=bind,from=hf-xet-builder,source=/tmp/hf-xet/dist,target=/tmp/hf-xet-wheels/ \ --mount=type=bind,from=torch,source=/tmp/pytorch/dist,target=/tmp/torch-wheels/ \ + --mount=type=bind,from=numba-builder,source=/tmp/llvmlite/dist,target=/tmp/llvmlite-wheels/ \ + --mount=type=bind,from=numba-builder,source=/tmp/numba/dist,target=/tmp/numba-wheels/ \ sed -i '/^torch/d' requirements/build.txt && \ - ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \ - VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \ - HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl | head -n 1) && \ - TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl | head -n 1) && \ + ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl) && \ + VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl) && \ + HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl) && \ + TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl) && \ + LLVM_WHL_FILE=$(ls /tmp/llvmlite-wheels/*.whl) && \ + NUMBA_WHL_FILE=$(ls /tmp/numba-wheels/*.whl) && \ uv pip install -v \ $ARROW_WHL_FILE \ $VISION_WHL_FILE \ $HF_XET_WHL_FILE \ $TORCH_WHL_FILE \ + $LLVM_WHL_FILE \ + $NUMBA_WHL_FILE \ --index-strategy unsafe-best-match \ -r requirements/build.txt \ - -r requirements/cpu.txt + -r requirements/cpu.txt + # Build and install vllm RUN --mount=type=cache,target=/root/.cache/uv \ - VLLM_TARGET_DEVICE=cpu python setup.py bdist_wheel && \ + VLLM_TARGET_DEVICE=cpu VLLM_CPU_MOE_PREPACK=0 python setup.py bdist_wheel && \ uv pip install "$(echo dist/*.whl)[tensorizer]" # setup non-root user for vllm @@ -196,4 +268,3 @@ WORKDIR /home/vllm # Set the default entrypoint ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"] - diff --git a/docker/Dockerfile.tpu b/docker/Dockerfile.tpu index 2190151369761..ca2d7833c1efa 100644 --- a/docker/Dockerfile.tpu +++ b/docker/Dockerfile.tpu @@ -7,7 +7,8 @@ WORKDIR /workspace/vllm # Install some basic utilities RUN apt-get update && apt-get install -y \ git \ - ffmpeg libsm6 libxext6 libgl1 + ffmpeg libsm6 libxext6 libgl1 && \ + rm -rf /var/lib/apt/lists/* # Build vLLM. COPY . . @@ -16,6 +17,9 @@ RUN --mount=type=bind,source=.git,target=.git \ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi # Remove existing versions of dependencies +# TODO: These packages will remain as dead weight in the Docker image layers. +# We should find a way to build the image without uninstalling these. +# Consider using a different base image. RUN pip uninstall -y torch torch_xla torchvision ENV VLLM_TARGET_DEVICE="tpu" @@ -23,9 +27,10 @@ RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=.git,target=.git \ python3 -m pip install \ -r requirements/tpu.txt -RUN python3 -m pip install -e . + +RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install -e . 
# install development dependencies (for testing) -RUN python3 -m pip install -e tests/vllm_test_utils +RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install -e tests/vllm_test_utils CMD ["/bin/bash"] diff --git a/docs/api/README.md b/docs/api/README.md index 327472df1d52c..57142e8f5625d 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -77,6 +77,7 @@ Internal data structures. - [vllm.multimodal.inputs.MultiModalFieldElem][] - [vllm.multimodal.inputs.MultiModalFieldConfig][] - [vllm.multimodal.inputs.MultiModalKwargsItem][] +- [vllm.multimodal.inputs.MultiModalKwargsItems][] - [vllm.multimodal.inputs.MultiModalKwargs][] - [vllm.multimodal.inputs.MultiModalInputs][] diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 2eeb8ad25de5f..357a5eb594060 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -48,7 +48,7 @@ You can tune the performance by adjusting `max_num_batched_tokens`: - Smaller values (e.g., 2048) achieve better inter-token latency (ITL) because there are fewer prefills slowing down decodes. - Higher values achieve better time to first token (TTFT) as you can process more prefill tokens in a batch. -- For optimal throughput, we recommend setting `max_num_batched_tokens > 8096` especially for smaller models on large GPUs. +- For optimal throughput, we recommend setting `max_num_batched_tokens > 8192` especially for smaller models on large GPUs. - If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the V0 default scheduling policy (except that it still prioritizes decodes). ```python @@ -129,6 +129,52 @@ Data parallelism replicates the entire model across multiple GPU sets and proces Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`. Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size. +### Batch-level DP for Multi-Modal Encoders + +By default, TP is used to shard the weights of multi-modal encoders just like for language decoders, +in order to reduce the memory and compute load on each GPU. + +However, since the size of multi-modal encoders is very small compared to language decoders, +there is relatively little gain from TP. On the other hand, TP incurs significant communication +overhead because of all-reduce being performed after every layer. + +Given this, it may be advantageous to instead shard the batched input data using TP, essentially +performing batch-level DP. This has been shown to improve the throughput by around 10% for +`tensor_parallel_size=8`. For vision encoders that use hardware-unoptimized Conv3D operations, +batch-level DP can provide another 40% increase to throughput compared to regular TP. + +Nevertheless, since the weights of the multi-modal encoder are replicated across each TP rank, +there will be a minor increase in memory consumption and may cause OOM if you can barely fit the model already. + +You can enable batch-level DP by setting `mm_encoder_tp_mode="data"`, for example: + +```python +from vllm import LLM + +llm = LLM( + model="Qwen/Qwen2.5-VL-72B-Instruct", + tensor_parallel_size=4, + # When mm_encoder_tp_mode="data", + # the vision encoder uses TP=4 (not DP=1) to shard the input data, + # so the TP size becomes the effective DP size. + # Note that this is independent of the DP size for language decoder which is used in expert parallel setting. 
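+    # With tensor_parallel_size=4 and mm_encoder_tp_mode="data", each TP rank
+    # keeps a full copy of the vision encoder weights and processes roughly
+    # 1/4 of the batched multi-modal inputs (effective DP across the TP ranks).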
+    mm_encoder_tp_mode="data",
+    # The language decoder uses TP=4 to shard the weights regardless
+    # of the setting of mm_encoder_tp_mode
+)
+```
+
+!!! important
+    Batch-level DP is not to be confused with API request-level DP
+    (which is instead controlled by `data_parallel_size`).
+
+The availability of batch-level DP depends on the model implementation.
+Currently, the following models support `mm_encoder_tp_mode="data"`:
+
+- Llama4 ()
+- Qwen2.5-VL ()
+- Step3 ()
+
 ## Input Processing

 ### Parallel Processing
diff --git a/docs/contributing/model/multimodal.md b/docs/contributing/model/multimodal.md
index 64a48be32645a..76d0f067fd452 100644
--- a/docs/contributing/model/multimodal.md
+++ b/docs/contributing/model/multimodal.md
@@ -629,7 +629,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargs,
+        out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         hf_config = self.info.get_hf_config()
         image_token_id = hf_config.image_token_index
@@ -778,7 +778,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargs,
+        out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         hf_config = self.info.get_hf_config()
         bos_token_id = hf_config.bos_token_id
diff --git a/docs/deployment/frameworks/dstack.md b/docs/deployment/frameworks/dstack.md
index 23dc58c974ed8..fe4d87f78f2aa 100644
--- a/docs/deployment/frameworks/dstack.md
+++ b/docs/deployment/frameworks/dstack.md
@@ -9,7 +9,7 @@ vLLM can be run on a cloud based GPU machine with [dstack](https://dstack.ai/),
 To install dstack client, run:

 ```bash
-pip install "dstack[all]
+pip install dstack[all]
 dstack server
 ```
diff --git a/docs/features/quantization/fp8.md b/docs/features/quantization/fp8.md
index 0661933acd61f..834c03cbe05b0 100644
--- a/docs/features/quantization/fp8.md
+++ b/docs/features/quantization/fp8.md
@@ -79,7 +79,7 @@ Since simple RTN does not require data for weight quantization and the activatio
 Install `vllm` and `lm-evaluation-harness` for evaluation:

 ```bash
-pip install vllm lm-eval==0.4.4
+pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
 ```

 Load and run the model in `vllm`:
diff --git a/docs/features/quantization/int4.md b/docs/features/quantization/int4.md
index 127e403989944..d6fdac7b07f7f 100644
--- a/docs/features/quantization/int4.md
+++ b/docs/features/quantization/int4.md
@@ -18,7 +18,7 @@ pip install llmcompressor
 Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:

 ```bash
-pip install vllm lm-eval==0.4.4
+pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
 ```

 ## Quantization Process
diff --git a/docs/features/quantization/int8.md b/docs/features/quantization/int8.md
index 45fae58a64868..247d0cbdd3f14 100644
--- a/docs/features/quantization/int8.md
+++ b/docs/features/quantization/int8.md
@@ -19,7 +19,7 @@ pip install llmcompressor
 Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:

 ```bash
-pip install vllm lm-eval==0.4.4
+pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
 ```

 ## Quantization Process
diff --git
a/docs/features/quantization/quark.md b/docs/features/quantization/quark.md index e8ed2155375d4..047cc8382445b 100644 --- a/docs/features/quantization/quark.md +++ b/docs/features/quantization/quark.md @@ -20,7 +20,7 @@ for more installation details. Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: ```bash -pip install vllm lm-eval==0.4.4 +pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] ``` ## Quantization Process diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md index f833807666460..2af26626d207d 100644 --- a/docs/getting_started/quickstart.md +++ b/docs/getting_started/quickstart.md @@ -8,7 +8,7 @@ This guide will help you quickly get started with vLLM to perform: ## Prerequisites - OS: Linux -- Python: 3.9 -- 3.12 +- Python: 3.9 -- 3.13 ## Installation diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index a514572945c3f..ad3db1cf2100f 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -363,7 +363,7 @@ th { | `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ | -| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ | | `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ | | `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | | ✅︎ | | `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | | ✅︎ | @@ -373,6 +373,7 @@ th { | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ | | `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | ✅︎ | | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ | | `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ | @@ -384,8 +385,8 @@ th { | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ | | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | | `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. 
| ✅︎ | ✅︎ | ✅︎ | -| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ | -| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ | +| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | | `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ | | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ | @@ -436,17 +437,17 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `BertModel`C | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | | -| `Gemma2Model`C | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | ✅︎ | -| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | -| `GteModel`C | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | | -| `GteNewModel`C | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | | -| `ModernBertModel`C | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | | -| `NomicBertModel`C | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | | +| `BertModel`C | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | ✅︎ | +| `Gemma2Model`C | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ | +| `GteModel`C | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | ✅︎ | +| `GteNewModel`C | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | ✅︎ | +| `ModernBertModel`C | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | ✅︎ | +| `NomicBertModel`C | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | ✅︎ | | `LlamaModel`C, `LlamaForCausalLM`C, `MistralModel`C, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2Model`C, `Qwen2ForCausalLM`C | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3Model`C, `Qwen3ForCausalLM`C | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | | +| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | ✅︎ | | `*Model`C, `*ForCausalLM`C, etc. | Generative models | N/A | \* | \* | \* | C Automatically converted into an embedding model via `--convert embed`. 
([details](./pooling_models.md#model-conversion)) @@ -476,7 +477,7 @@ These models primarily support the [`LLM.classify`](./pooling_models.md#llmclass | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | | +| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ | | `*Model`C, `*ForCausalLM`C, etc. | Generative models | N/A | \* | \* | \* | @@ -493,12 +494,12 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | | +| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ | | `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | -| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | | -| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | | +| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ | +| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | ✅︎ | | `*Model`C, `*ForCausalLM`C, etc. | Generative models | N/A | \* | \* | \* | C Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion)) @@ -626,7 +627,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ | -| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ | +| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ | | `Llama4ForConditionalGeneration` | Llama 4 | T + I+ | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. 
| | ✅︎ | ✅︎ | | `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + IE+ | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ | | `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + IE+ | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ | @@ -641,6 +642,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `MolmoForCausalLM` | Molmo | T + I+ | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ | | `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ | | `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ | +| `Ovis2_5` | Ovis2.5 | T + I+ + V | `AIDC-AI/Ovis2.5-9B`, etc. | | | ✅︎ | | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ | | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ | | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -651,6 +653,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + IE+ + VE+ | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + IE+ + VE+ | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ | +| `RForConditionalGeneration` | R-VL-4B | T + IE+ | `YannQi/R-4B` | | ✅︎ | ✅︎ | | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | | `Step3VLForConditionalGeneration` | Step3-VL | T + I+ | `stepfun-ai/step3` | | ✅︎ | ✅︎ | diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 54af970ea842d..b89768913681e 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -107,7 +107,7 @@ to enable simultaneous generation and embedding using the same engine instance i #### Mamba Models Models using selective state-space mechanisms instead of standard transformer attention are supported. -Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`. +Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, `Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). Please note that @@ -154,12 +154,15 @@ differences compared to V0: ##### Logprobs Calculation -Logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e. 
+By default, logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e. before applying any logits post-processing such as temperature scaling or penalty adjustments). As a result, the returned logprobs do not reflect the final adjusted probabilities used during sampling. -Support for logprobs with post-sampling adjustments is in progress and will be added in future updates. +You can adjust this behavior by setting the `--logprobs-mode` flag. +Four modes are supported: `raw_logprobs` (default), `processed_logprobs`, `raw_logits`, `processed_logits`. +Raw means the values before applying any logit processors, like bad words. +Processed means the values after applying all processors, including temperature and top_k/top_p. ##### Prompt Logprobs with Prefix Caching diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 184c30891eca7..c4972f02d0f8e 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -5,6 +5,7 @@ from transformers import AutoTokenizer from vllm import LLM, SamplingParams from vllm.benchmarks.datasets import add_dataset_parser, get_samples +from vllm.inputs import TokensPrompt from vllm.v1.metrics.reader import Counter, Vector try: @@ -137,7 +138,8 @@ def main(): sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) if not args.custom_mm_prompts: outputs = llm.generate( - prompt_token_ids=prompt_ids, sampling_params=sampling_params + TokensPrompt(prompt_token_ids=prompt_ids), + sampling_params=sampling_params, ) else: outputs = llm.chat(prompts, sampling_params=sampling_params) diff --git a/examples/offline_inference/structured_outputs.py b/examples/offline_inference/structured_outputs.py index f46064931dbac..88d87beb4874d 100644 --- a/examples/offline_inference/structured_outputs.py +++ b/examples/offline_inference/structured_outputs.py @@ -85,7 +85,7 @@ def format_output(title: str, output: str): def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM): - outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) + outputs = llm.generate(prompt, sampling_params=sampling_params) return outputs[0].outputs[0].text diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 988ad35cdd7e6..8d97ba2668263 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -283,8 +283,10 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData: ) prompts = [ - f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\ - {question}<|assistant|>" + ( + "<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>" + f"{question}<|assistant|>" + ) for question in questions ] @@ -333,6 +335,80 @@ def run_glm4_1v(questions: list[str], modality: str) -> ModelRequestData: ) +# GLM-4.5V +def run_glm4_5v(questions: list[str], modality: str) -> ModelRequestData: + model_name = "zai-org/GLM-4.5V" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + mm_processor_kwargs={ + "size": {"shortest_edge": 12544, "longest_edge": 47040000}, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + enforce_eager=True, + tensor_parallel_size=4, + ) + + if modality == "image": + placeholder = "<|begin_of_image|><|image|><|end_of_image|>" + elif modality == "video": + placeholder = "<|begin_of_video|><|video|><|end_of_video|>" + + prompts = [ + ( + 
"[gMASK]<|system|>\nYou are a helpful assistant.<|user|>\n" + f"{placeholder}" + f"{question}<|assistant|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + +# GLM-4.5V-FP8 +def run_glm4_5v_fp8(questions: list[str], modality: str) -> ModelRequestData: + model_name = "zai-org/GLM-4.5V-FP8" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + mm_processor_kwargs={ + "size": {"shortest_edge": 12544, "longest_edge": 47040000}, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + enforce_eager=True, + tensor_parallel_size=4, + ) + + if modality == "image": + placeholder = "<|begin_of_image|><|image|><|end_of_image|>" + elif modality == "video": + placeholder = "<|begin_of_video|><|video|><|end_of_video|>" + + prompts = [ + ( + "[gMASK]<|system|>\nYou are a helpful assistant.<|user|>\n" + f"{placeholder}" + f"{question}<|assistant|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # H2OVL-Mississippi def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -383,8 +459,8 @@ def run_hyperclovax_seed_vision( for question in questions: if modality == "image": """ - ocr: List the words in the image in raster order. - Even if the word order feels unnatural for reading, + ocr: List the words in the image in raster order. + Even if the word order feels unnatural for reading, the model will handle it as long as it follows raster order. e.g. "Naver, CLOVA, bigshane" lens_keywords: List the entity names in the image. @@ -693,15 +769,13 @@ def run_llava_next_video(questions: list[str], modality: str) -> ModelRequestDat def run_llava_onevision(questions: list[str], modality: str) -> ModelRequestData: if modality == "video": prompts = [ - f"<|im_start|>user