diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py
index 7045d8810493e..bbed80ebe8476 100644
--- a/.buildkite/generate_index.py
+++ b/.buildkite/generate_index.py
@@ -8,7 +8,8 @@ template = """
Links for vLLM
-    <a href="../{wheel_html_escaped}">{wheel}</a><br/>
+    <a href="../{x86_wheel_html_escaped}">{x86_wheel}</a><br/>
+    <a href="../{arm_wheel_html_escaped}">{arm_wheel}</a><br/>
"""
@@ -21,7 +22,25 @@ filename = os.path.basename(args.wheel)
with open("index.html", "w") as f:
print(f"Generated index.html for {args.wheel}")
+ # sync the abi tag with .buildkite/scripts/upload-wheels.sh
+ if "x86_64" in filename:
+ x86_wheel = filename
+ arm_wheel = filename.replace("x86_64", "aarch64").replace(
+ "manylinux1", "manylinux2014"
+ )
+ elif "aarch64" in filename:
+ x86_wheel = filename.replace("aarch64", "x86_64").replace(
+ "manylinux2014", "manylinux1"
+ )
+ arm_wheel = filename
+ else:
+        raise ValueError(f"Unsupported wheel: {filename}")
# cloudfront requires escaping the '+' character
f.write(
- template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B"))
+ template.format(
+ x86_wheel=x86_wheel,
+ x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"),
+ arm_wheel=arm_wheel,
+ arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"),
+ )
)
diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
deleted file mode 100644
index 56ec933c9cc0e..0000000000000
--- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
+++ /dev/null
@@ -1,12 +0,0 @@
-# For vllm script, with -t option (tensor parallel size).
-# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
-model_name: "HandH1998/QQQ-Llama-3-8b-g128"
-tasks:
-- name: "gsm8k"
- metrics:
- - name: "exact_match,strict-match"
- value: 0.419
- - name: "exact_match,flexible-extract"
- value: 0.416
-limit: 1000
-num_fewshot: 5
diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt
index 27a1a9a82bd35..37eeac85c933b 100644
--- a/.buildkite/lm-eval-harness/configs/models-large.txt
+++ b/.buildkite/lm-eval-harness/configs/models-large.txt
@@ -3,4 +3,3 @@ Meta-Llama-3-70B-Instruct.yaml
Mixtral-8x7B-Instruct-v0.1.yaml
Qwen2-57B-A14-Instruct.yaml
DeepSeek-V2-Lite-Chat.yaml
-Meta-Llama-3-8B-QQQ.yaml
diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml
index e20ce54ca795a..f96c38bf57db7 100644
--- a/.buildkite/release-pipeline.yaml
+++ b/.buildkite/release-pipeline.yaml
@@ -27,7 +27,12 @@ steps:
env:
DOCKER_BUILDKIT: "1"
+ - block: "Build CUDA 12.6 wheel"
+ key: block-build-cu126-wheel
+ depends_on: ~
+
- label: "Build wheel - CUDA 12.6"
+ depends_on: block-build-cu126-wheel
id: build-wheel-cuda-12-6
agents:
queue: cpu_queue_postmerge
diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh
index 57a7bc4e5f5df..9dec9f8e9eb32 100644
--- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh
@@ -46,6 +46,11 @@ function cpu_tests() {
set -e
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
+ # Run kernel tests
+ docker exec cpu-test-"$NUMA_NODE" bash -c "
+ set -e
+ pytest -v -s tests/kernels/test_onednn.py"
+
# Run basic model test
docker exec cpu-test-"$NUMA_NODE" bash -c "
set -e
@@ -99,4 +104,4 @@ function cpu_tests() {
# All of CPU tests are expected to be finished less than 40 mins.
export -f cpu_tests
-timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
+timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
diff --git a/.buildkite/scripts/tpu/cleanup_docker.sh b/.buildkite/scripts/tpu/cleanup_docker.sh
index 209d9c4341cdd..740d81fb39bb0 100755
--- a/.buildkite/scripts/tpu/cleanup_docker.sh
+++ b/.buildkite/scripts/tpu/cleanup_docker.sh
@@ -17,7 +17,7 @@ if [ "$disk_usage" -gt "$threshold" ]; then
# Remove dangling images (those that are not tagged and not used by any container)
docker image prune -f
# Remove unused volumes / force the system prune for old images as well.
- docker volume prune -f && docker system prune --force --filter "until=72h" --all
+ docker volume prune -f && docker system prune --force --filter "until=24h" --all
echo "Docker images and volumes cleanup completed."
else
echo "Disk usage is below $threshold%. No cleanup needed."
diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh
index 037897e53dbef..745f285c008ad 100644
--- a/.buildkite/scripts/upload-wheels.sh
+++ b/.buildkite/scripts/upload-wheels.sh
@@ -14,8 +14,19 @@ fi
# Get the single wheel file
wheel="${wheel_files[0]}"
-# Rename 'linux' to 'manylinux1' in the wheel filename
-new_wheel="${wheel/linux/manylinux1}"
+# Detect architecture and rename 'linux' to appropriate manylinux version
+arch=$(uname -m)
+if [[ $arch == "x86_64" ]]; then
+ manylinux_version="manylinux1"
+elif [[ $arch == "aarch64" ]]; then
+ manylinux_version="manylinux2014"
+else
+ echo "Warning: Unknown architecture $arch, using manylinux1 as default"
+ manylinux_version="manylinux1"
+fi
+
+# Rename 'linux' to the appropriate manylinux version in the wheel filename
+new_wheel="${wheel/linux/$manylinux_version}"
mv -- "$wheel" "$new_wheel"
wheel="$new_wheel"
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 2f7f1db75bfb9..1f67e7e92bd11 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -328,6 +328,7 @@ steps:
- pytest -v -s compile/test_sequence_parallelism.py
- pytest -v -s compile/test_async_tp.py
- pytest -v -s compile/test_fusion_all_reduce.py
+ - pytest -v -s compile/test_decorator.py
- label: PyTorch Fullgraph Smoke Test # 9min
mirror_hardwares: [amdexperimental]
@@ -341,6 +342,7 @@ steps:
- pytest -v -s compile/piecewise/test_simple.py
- pytest -v -s compile/piecewise/test_toy_llama.py
- pytest -v -s compile/piecewise/test_full_cudagraph.py
+ - pytest -v -s compile/piecewise/test_multiple_graphs.py
- label: PyTorch Fullgraph Test # 18min
mirror_hardwares: [amdexperimental]
@@ -543,6 +545,15 @@ steps:
commands:
- pytest -v -s models/language/pooling -m 'not core_model'
+- label: Multi-Modal Processor Test
+ source_file_dependencies:
+ - vllm/
+ - tests/models/multimodal
+ commands:
+ - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
+ - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
+ - pytest -v -s models/multimodal/processing/test_tensor_schema.py
+
- label: Multi-Modal Models Test (Standard)
mirror_hardwares: [amdexperimental]
torch_nightly: true
@@ -552,9 +563,7 @@ steps:
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pip freeze | grep -E 'torch'
- - pytest -v -s models/multimodal/processing
- - pytest -v -s --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/test_tensor_schema.py models/multimodal -m core_model
- - pytest -v -s models/multimodal/test_tensor_schema.py -m core_model # Needs mp_method="spawn"
+ - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
- cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
- label: Multi-Modal Models Test (Extended) 1
@@ -565,7 +574,7 @@ steps:
- tests/models/multimodal
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- - pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model'
+ - pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing
- label: Multi-Modal Models Test (Extended) 2
mirror_hardwares: [amdexperimental]
@@ -646,6 +655,7 @@ steps:
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
+ - pytest -v -s tests/kernels/moe/test_mxfp4_moe.py
# Fusion
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml
deleted file mode 100644
index 2b1086b7faf43..0000000000000
--- a/.github/workflows/lint-and-deploy.yaml
+++ /dev/null
@@ -1,89 +0,0 @@
-name: Lint and Deploy Charts
-
-on: pull_request
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
- cancel-in-progress: true
-
-permissions:
- contents: read
-
-jobs:
- lint-and-deploy:
- runs-on: ubuntu-latest
- steps:
- - name: Checkout
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- with:
- fetch-depth: 0
-
- - name: Set up Helm
- uses: azure/setup-helm@b9e51907a09c216f16ebe8536097933489208112 # v4.3.0
- with:
- version: v3.14.4
-
- #Python is required because ct lint runs Yamale and yamllint which require Python.
- - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
- with:
- python-version: '3.13'
-
- - name: Set up chart-testing
- uses: helm/chart-testing-action@0d28d3144d3a25ea2cc349d6e59901c4ff469b3b # v2.7.0
- with:
- version: v3.10.1
-
- - name: Run chart-testing (lint)
- run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/online_serving/chart-helm --charts examples/online_serving/chart-helm
-
- - name: Setup minio
- run: |
- docker network create vllm-net
- docker run -d -p 9000:9000 --name minio --net vllm-net \
- -e "MINIO_ACCESS_KEY=minioadmin" \
- -e "MINIO_SECRET_KEY=minioadmin" \
- -v /tmp/data:/data \
- -v /tmp/config:/root/.minio \
- minio/minio server /data
- export AWS_ACCESS_KEY_ID=minioadmin
- export AWS_SECRET_ACCESS_KEY=minioadmin
- export AWS_EC2_METADATA_DISABLED=true
- mkdir opt-125m
- cd opt-125m && curl -O -Ls "https://huggingface.co/facebook/opt-125m/resolve/main/{pytorch_model.bin,config.json,generation_config.json,merges.txt,special_tokens_map.json,tokenizer_config.json,vocab.json}" && cd ..
- aws --endpoint-url http://127.0.0.1:9000/ s3 mb s3://testbucket
- aws --endpoint-url http://127.0.0.1:9000/ s3 cp opt-125m/ s3://testbucket/opt-125m --recursive
-
- - name: Create kind cluster
- uses: helm/kind-action@a1b0e391336a6ee6713a0583f8c6240d70863de3 # v1.12.0
-
- - name: Build the Docker image vllm cpu
- run: docker buildx build -f docker/Dockerfile.cpu -t vllm-cpu-env .
-
- - name: Configuration of docker images, network and namespace for the kind cluster
- run: |
- docker pull amazon/aws-cli:2.6.4
- kind load docker-image amazon/aws-cli:2.6.4 --name chart-testing
- kind load docker-image vllm-cpu-env:latest --name chart-testing
- docker network connect vllm-net "$(docker ps -aqf "name=chart-testing-control-plane")"
- kubectl create ns ns-vllm
-
- - name: Run chart-testing (install)
- run: |
- export AWS_ACCESS_KEY_ID=minioadmin
- export AWS_SECRET_ACCESS_KEY=minioadmin
- sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
- helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
-
- - name: curl test
- run: |
- kubectl -n ns-vllm port-forward service/test-vllm-service 8001:80 &
- sleep 10
- CODE="$(curl -v -f --location http://localhost:8001/v1/completions \
- --header "Content-Type: application/json" \
- --data '{
- "model": "opt-125m",
- "prompt": "San Francisco is a",
- "max_tokens": 7,
- "temperature": 0
- }'):$CODE"
- echo "$CODE"
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
deleted file mode 100644
index bfd02879965ee..0000000000000
--- a/.github/workflows/publish.yml
+++ /dev/null
@@ -1,111 +0,0 @@
-# This workflow will upload a Python Package to Release asset
-# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions
-
-name: Create Release
-
-on:
- push:
- tags:
- - v*
-
-# Needed to create release and upload assets
-permissions:
- contents: write
-
-jobs:
- release:
- # Retrieve tag and create release
- name: Create Release
- runs-on: ubuntu-latest
- outputs:
- upload_url: ${{ steps.create_release.outputs.upload_url }}
- steps:
- - name: Checkout
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-
- - name: Extract branch info
- shell: bash
- run: |
- echo "release_tag=${GITHUB_REF#refs/*/}" >> "$GITHUB_ENV"
-
- - name: Create Release
- id: create_release
- uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
- env:
- RELEASE_TAG: ${{ env.release_tag }}
- with:
- github-token: "${{ secrets.GITHUB_TOKEN }}"
- script: |
- const script = require('.github/workflows/scripts/create_release.js')
- await script(github, context, core)
-
- # NOTE(simon): No longer build wheel using GitHub Actions. See buildkite's release workflow.
- # wheel:
- # name: Build Wheel
- # runs-on: ${{ matrix.os }}
- # needs: release
-
- # strategy:
- # fail-fast: false
- # matrix:
- # os: ['ubuntu-20.04']
- # python-version: ['3.9', '3.10', '3.11', '3.12']
- # pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements/cuda.txt.
- # cuda-version: ['11.8', '12.1']
-
- # steps:
- # - name: Checkout
- # uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-
- # - name: Setup ccache
- # uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14
- # with:
- # create-symlink: true
- # key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }}
-
- # - name: Set up Linux Env
- # if: ${{ runner.os == 'Linux' }}
- # run: |
- # bash -x .github/workflows/scripts/env.sh
-
- # - name: Set up Python
- # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
- # with:
- # python-version: ${{ matrix.python-version }}
-
- # - name: Install CUDA ${{ matrix.cuda-version }}
- # run: |
- # bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}
-
- # - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }}
- # run: |
- # bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }}
-
- # - name: Build wheel
- # shell: bash
- # env:
- # CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size
- # run: |
- # bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
- # wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename)
- # asset_name=${wheel_name//"linux"/"manylinux1"}
- # echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV"
- # echo "asset_name=${asset_name}" >> "$GITHUB_ENV"
-
- # - name: Upload Release Asset
- # uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2
- # env:
- # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- # with:
- # upload_url: ${{ needs.release.outputs.upload_url }}
- # asset_path: ./dist/${{ env.wheel_name }}
- # asset_name: ${{ env.asset_name }}
- # asset_content_type: application/*
-
- # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
- # - name: Publish package
- # uses: pypa/gh-action-pypi-publish@release/v1.8
- # with:
- # repository-url: https://test.pypi.org/legacy/
- # password: ${{ secrets.PYPI_API_TOKEN }}
- # skip-existing: true
diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml
index 16ae1aadb96be..1ee605dc7bb0d 100644
--- a/.github/workflows/reminder_comment.yml
+++ b/.github/workflows/reminder_comment.yml
@@ -12,16 +12,43 @@ jobs:
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
with:
script: |
- github.rest.issues.createComment({
- owner: context.repo.owner,
- repo: context.repo.repo,
- issue_number: context.issue.number,
- body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' +
- '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' +
- 'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org.\n\n' +
- 'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' +
- 'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' +
- '🚀'
- })
+ try {
+ // Get the PR author
+ const prAuthor = context.payload.pull_request.user.login;
+
+ // Check if this is the author's first PR in this repository
+ // Use GitHub's search API to find all PRs by this author
+ const { data: searchResults } = await github.rest.search.issuesAndPullRequests({
+ q: `repo:${context.repo.owner}/${context.repo.repo} type:pr author:${prAuthor}`,
+ per_page: 100
+ });
+
+ const authorPRCount = searchResults.total_count;
+
+ console.log(`Found ${authorPRCount} PRs by ${prAuthor}`);
+
+ // Only post comment if this is the first PR (only one PR by this author)
+ if (authorPRCount === 1) {
+ console.log(`Posting welcome comment for first-time contributor: ${prAuthor}`);
+ await github.rest.issues.createComment({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ issue_number: context.issue.number,
+ body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' +
+ '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' +
+ 'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. \n\n' +
+ 'You ask your reviewers to trigger select CI tests on top of `fastcheck` CI. \n\n' +
+ 'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' +
+ 'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' +
+ 'If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.\n\n' +
+ '🚀'
+ });
+ } else {
+ console.log(`Skipping comment for ${prAuthor} - not their first PR (${authorPRCount} PRs found)`);
+ }
+ } catch (error) {
+ console.error('Error checking PR history or posting comment:', error);
+ // Don't fail the workflow, just log the error
+ }
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/CMakeLists.txt b/CMakeLists.txt
index bcbd1b52a06c6..a1deefb07f09c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -357,9 +357,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
set(MARLIN_SRCS
- "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
- "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py
index e1a856026c4ae..2ea4f9ccaff2b 100644
--- a/benchmarks/benchmark_dataset.py
+++ b/benchmarks/benchmark_dataset.py
@@ -958,8 +958,10 @@ class InstructCoderDataset(HuggingFaceDataset):
for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
- prompt = f"{item['input']}\n\n{item['instruction']} Just output \
- the code, do not include any explanation."
+ prompt = (
+ f"{item['input']}\n\n{item['instruction']} Just output "
+ "the code, do not include any explanation."
+ )
# apply template
prompt = tokenizer.apply_chat_template(
diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
index 1d4e730f99ae9..a6b42406b5cb0 100644
--- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
+++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
@@ -80,6 +80,11 @@ def bench_run(
a, score, topk, renormalize=False
)
+ ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+ ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
+ c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
+ c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+
def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
@@ -111,6 +116,10 @@ def bench_run(
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
+ ab_strides1: torch.Tensor,
+ ab_strides2: torch.Tensor,
+ c_strides1: torch.Tensor,
+ c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
@@ -125,6 +134,10 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
per_act_token,
a1_scale=None,
)
@@ -136,6 +149,10 @@ def bench_run(
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
+ ab_strides1: torch.Tensor,
+ ab_strides2: torch.Tensor,
+ c_strides1: torch.Tensor,
+ c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
):
@@ -150,6 +167,10 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
per_act_token,
a1_scale=None,
)
@@ -194,6 +215,10 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
topk_weights,
topk_ids,
)
@@ -231,6 +256,10 @@ def bench_run(
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
+ "ab_strides1": ab_strides1,
+ "ab_strides2": ab_strides2,
+ "c_strides1": c_strides1,
+ "c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@@ -289,6 +318,10 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
topk_weights,
topk_ids,
per_act_token,
@@ -297,7 +330,7 @@ def bench_run(
results.append(
benchmark.Timer(
- stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
+ stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py
index 975d10f2e92ec..a9c4d30d9b189 100644
--- a/benchmarks/kernels/benchmark_machete.py
+++ b/benchmarks/kernels/benchmark_machete.py
@@ -253,28 +253,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
else:
assert bt.a.dtype == torch.int8
assert bt.wtype == scalar_types.uint4b8
-
- if bt.w_ch_s is not None:
- s_ch = bt.w_ch_s.to(torch.float32)
- else:
- s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device)
-
- if bt.w_tok_s is not None:
- s_tok = bt.w_tok_s.to(torch.float32)
- else:
- s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device)
-
- fn = lambda: ops.marlin_qqq_gemm(
- a=bt.a,
- b_q_weight=w_q,
- s_group=w_s,
- s_tok=s_tok,
- s_ch=s_ch,
- workspace=workspace.scratch,
- size_m=bt.a.shape[0],
- size_n=bt.w_ref.shape[1],
- size_k=bt.w_ref.shape[0],
- )
+ raise NotImplementedError("QQQ is not supported anymore")
return fn
diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
new file mode 100644
index 0000000000000..0650cbf3cc18e
--- /dev/null
+++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import time
+
+import torch
+
+from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
+ silu_mul_fp8_quant_deep_gemm,
+)
+from vllm.platforms import current_platform
+
+
+def benchmark(E, T, H, G=128, runs=50):
+ current_platform.seed_everything(42)
+ y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
+ tokens_per_expert = torch.randint(
+ T // 2, T, size=(E,), dtype=torch.int32, device="cuda"
+ )
+
+ # Warmup
+ for _ in range(10):
+ silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
+ torch.cuda.synchronize()
+
+ # Benchmark
+ torch.cuda.synchronize()
+ start = time.perf_counter()
+ for _ in range(runs):
+ silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
+ torch.cuda.synchronize()
+
+ avg_time = (time.perf_counter() - start) / runs * 1000
+
+ # Calculate actual work done (only count valid tokens)
+ actual_tokens = tokens_per_expert.sum().item()
+ actual_elements = actual_tokens * H
+
+ # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops
+ ops_per_element = 8
+ total_ops = actual_elements * ops_per_element
+ gflops = total_ops / (avg_time / 1000) / 1e9
+
+ # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes)
+ input_bytes = actual_tokens * 2 * H * 2 # 2*H bfloat16 inputs
+ output_bytes = actual_tokens * H * 1 # H fp8 outputs
+ scale_bytes = actual_tokens * (H // G) * 4 # scales in float32
+ total_bytes = input_bytes + output_bytes + scale_bytes
+ memory_bw = total_bytes / (avg_time / 1000) / 1e9
+
+ return avg_time, gflops, memory_bw
+
+
+configs = [
+ (8, 32, 1024),
+ (16, 64, 2048),
+ (32, 128, 4096),
+ # DeepSeekV3 Configs
+ (256, 16, 7168),
+ (256, 32, 7168),
+ (256, 64, 7168),
+ (256, 128, 7168),
+ (256, 256, 7168),
+ (256, 512, 7168),
+ (256, 1024, 7168),
+]
+
+print(f"GPU: {torch.cuda.get_device_name()}")
+print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
+print("-" * 50)
+
+for E, T, H in configs:
+ try:
+ time_ms, gflops, gbps = benchmark(E, T, H)
+ print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
+ except Exception:
+ print(f"E={E:3d},T={T:4d},H={H:4d} FAILED")
diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py
index b3f81715461b1..72b54b40a2d1e 100644
--- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py
+++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py
@@ -110,7 +110,7 @@ def benchmark_decode(
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
- use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
+ use_tensor_cores=True,
)
wrapper.plan(
kv_indptr,
diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index e0da46e2accaa..cc38cd41a5b24 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -182,17 +182,17 @@ endif()
#
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
# Flag to enable ACL kernels for AARCH64 platforms
-if ( VLLM_BUILD_ACL STREQUAL "ON")
+if (VLLM_BUILD_ACL STREQUAL "ON")
set(USE_ACL ON)
else()
set(USE_ACL OFF)
endif()
-if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
+if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
FetchContent_Declare(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
- GIT_TAG v3.8.1
+ GIT_TAG v3.9
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
@@ -204,7 +204,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
endif()
set(ONEDNN_AARCH64_USE_ACL "ON")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
- endif()
+ endif()
set(ONEDNN_LIBRARY_TYPE "STATIC")
set(ONEDNN_BUILD_DOC "OFF")
@@ -217,38 +217,23 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
+ set(ONEDNN_VERBOSE "OFF")
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
FetchContent_MakeAvailable(oneDNN)
-
- list(APPEND LIBS dnnl)
-elseif(POWER10_FOUND)
- FetchContent_Declare(
- oneDNN
- GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
- GIT_TAG v3.7.2
- GIT_PROGRESS TRUE
- GIT_SHALLOW TRUE
+ add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp")
+ target_include_directories(
+ dnnl_ext
+ PUBLIC ${oneDNN_SOURCE_DIR}/include
+ PUBLIC ${oneDNN_BINARY_DIR}/include
+ PRIVATE ${oneDNN_SOURCE_DIR}/src
)
-
- set(ONEDNN_LIBRARY_TYPE "STATIC")
- set(ONEDNN_BUILD_DOC "OFF")
- set(ONEDNN_BUILD_EXAMPLES "OFF")
- set(ONEDNN_BUILD_TESTS "OFF")
- set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
- set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
- set(ONEDNN_BUILD_GRAPH "OFF")
- set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
- set(ONEDNN_ENABLE_ITT_TASKS "OFF")
- set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
- set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
- set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
-
- set(DNNL_CPU_RUNTIME "OMP")
-
- FetchContent_MakeAvailable(oneDNN)
-
- list(APPEND LIBS dnnl)
+ target_link_libraries(dnnl_ext dnnl)
+ target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC)
+ list(APPEND LIBS dnnl_ext)
+ set(USE_ONEDNN ON)
+else()
+ set(USE_ONEDNN OFF)
endif()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
@@ -275,7 +260,6 @@ set(VLLM_EXT_SRC
if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC
- "csrc/cpu/quant.cpp"
"csrc/cpu/shm.cpp"
${VLLM_EXT_SRC})
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
@@ -289,14 +273,11 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
${VLLM_EXT_SRC})
add_compile_definitions(-DCPU_CAPABILITY_AVX512)
endif()
-elseif(POWER10_FOUND)
- set(VLLM_EXT_SRC
- "csrc/cpu/quant.cpp"
- ${VLLM_EXT_SRC})
endif()
-if (ASIMD_FOUND)
+
+if(USE_ONEDNN)
set(VLLM_EXT_SRC
- "csrc/cpu/quant.cpp"
+ "csrc/cpu/dnnl_kernels.cpp"
${VLLM_EXT_SRC})
endif()
diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu
index e0e95d06290df..6dd6f269f3dc9 100644
--- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu
+++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu
@@ -167,7 +167,7 @@ typename T::Fmha::Arguments args_from_options(
// TODO(trevor-m): Change split_kv back to -1 when
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
// perform worse with larger context length and smaller batch sizes.
- num_kv_splits, // split_kv
+      static_cast<int>(num_kv_splits),  // split_kv
nullptr, // is_var_split_kv
};
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
@@ -264,7 +264,7 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba
// Assumes device 0 when getting sm_count.
arguments.hw_info.sm_count =
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
- arguments.split_kv = num_kv_splits;
+  arguments.split_kv = static_cast<int>(num_kv_splits);
MlaSm100Type::Fmha::set_split_kv(arguments);
return MlaSm100Type::Fmha::get_workspace_size(arguments);
diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp
index 3952c43cbc727..982f7c07a13bd 100644
--- a/csrc/cpu/cpu_types_x86.hpp
+++ b/csrc/cpu/cpu_types_x86.hpp
@@ -89,7 +89,7 @@ struct FP16Vec16 : public Vec {
explicit FP16Vec16(const FP32Vec16&);
- void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
+ void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
void save(void* ptr, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
@@ -126,7 +126,7 @@ struct BF16Vec16 : public Vec {
explicit BF16Vec16(const FP32Vec16&);
- void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
+ void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
void save(void* ptr, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
@@ -180,8 +180,8 @@ struct BF16Vec32 : public Vec {
(__m128i)vec8_data.reg, 1)) {}
void save(void* ptr) const {
- *reinterpret_cast<__m256i*>(ptr) = reg_low;
- *reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high;
+ _mm256_storeu_si256((__m256i*)ptr, reg_low);
+ _mm256_storeu_si256((__m256i*)ptr + 1, reg_high);
}
};
#endif
diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp
new file mode 100644
index 0000000000000..f3f00edb36068
--- /dev/null
+++ b/csrc/cpu/dnnl_helper.cpp
@@ -0,0 +1,346 @@
+#include
+#include
+
+#include "common/memory_desc.hpp"
+#include "common/memory.hpp"
+
+#include "dnnl_helper.h"
+
+static dnnl::engine& default_engine() {
+ static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
+ return engine;
+}
+
+static dnnl::stream& default_stream() {
+ static dnnl::stream stream(default_engine());
+ return stream;
+}
+
+void release_dnnl_matmul_handler(int64_t handler) {
+  DNNLMatMulPrimitiveHandler* ptr =
+      reinterpret_cast<DNNLMatMulPrimitiveHandler*>(handler);
+  delete ptr;
+}
+
+template <typename KT, typename VT>
+class DNNLPrimitiveCache {
+ public:
+  using cache_value_t = std::pair<KT, VT>;
+  using result_value_t = VT;
+  using container_t = std::list<cache_value_t>;
+  using value_iterator_t = typename container_t::iterator;
+  using map_t = std::unordered_map<KT, value_iterator_t>;
+  using creator_t = VT (*)();
+
+ public:
+ DNNLPrimitiveCache(size_t capacity)
+ : capacity_(capacity),
+ values_(),
+ key_to_value_(std::min(256lu, capacity)) {
+ assert(capacity > 0);
+ }
+
+  template <typename F>
+  result_value_t get_or_create(const KT& key, F&& creator) {
+    std::optional<value_iterator_t> value = get_value(key);
+ if (value.has_value()) {
+ return value.value()->second;
+ } else {
+ return add_value({key, creator()})->second;
+ }
+ }
+
+ size_t size() const { return values_.size(); }
+
+ private:
+  void dump_data() {
+    std::stringstream ss;
+    ss << "table_id: " << std::hex << reinterpret_cast<size_t>(this) << std::dec
+       << "\n";
+    ss << "container: [";
+    for (auto&& iter : values_) {
+      ss << "(" << iter.first << ", " << std::hex
+         << reinterpret_cast<size_t>(iter.second.get()) << "), " << std::dec;
+    }
+    ss << "]\n";
+
+    ss << "map: [";
+    for (auto&& iter : key_to_value_) {
+      ss << "(" << iter.first << ", " << iter.second->first << ", " << std::hex
+         << reinterpret_cast<size_t>(iter.second->second.get()) << std::dec
+         << "), ";
+    }
+    ss << "]\n";
+    std::printf("%s\n", ss.str().c_str());
+  }
+
+ value_iterator_t add_value(cache_value_t&& new_value) {
+ if (size() == capacity_) {
+ cache_value_t& last_item = values_.back();
+ key_to_value_.erase(last_item.first);
+ values_.pop_back();
+ }
+
+ auto& added_value_ = values_.emplace_front(std::move(new_value));
+ key_to_value_.emplace(added_value_.first, values_.begin());
+ return values_.begin();
+ }
+
+  std::optional<value_iterator_t> get_value(const KT& key) {
+ if (key_to_value_.size() > 0 && key == values_.begin()->first) {
+ return values_.begin();
+ }
+
+ auto value_map_iterator = key_to_value_.find(key);
+ if (value_map_iterator != key_to_value_.end()) {
+ values_.splice(values_.begin(), values_, value_map_iterator->second);
+ return value_map_iterator->second;
+ } else {
+ return {};
+ }
+ }
+
+ private:
+ const size_t capacity_;
+ container_t values_;
+ map_t key_to_value_;
+};
+
+DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler(
+ const Args& args, dnnl::memory::data_type b_type)
+ : b_n_size_(args.b_n_size),
+ b_n_stride_(args.b_n_stride),
+ b_k_size_(args.b_k_size),
+ b_k_stride_(args.b_k_stride),
+ b_type_(b_type),
+ c_type_(args.c_type),
+ runtime_memory_ptrs_(8),
+ primitive_cache_size_(args.primitive_cache_size) {
+ assert(primitive_cache_size_ > 0);
+}
+
+void DNNLMatMulPrimitiveHandler::prepack_weight(
+ void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) {
+ dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
+ {b_k_stride_, b_n_stride_});
+ dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr);
+ dnnl::memory packed_weight(b_target_mem_desc, default_engine());
+ {
+ dnnl::reorder(original_weight, packed_weight)
+ .execute(default_stream(), original_weight, packed_weight);
+ default_stream().wait();
+ }
+ memory_cache_[DNNL_ARG_WEIGHTS] = packed_weight;
+ b_target_mem_desc_ = b_target_mem_desc;
+}
+
+void DNNLMatMulPrimitiveHandler::set_runtime_memory_ptr(
+    size_t index, dnnl_memory* memory_ptr) {
+  dnnl::impl::memory_storage_t* mem_storage_ptr = memory_ptr->memory_storage();
+  dnnl_memory_desc* mem_desc = const_cast<dnnl_memory_desc*>(memory_ptr->md());
+  runtime_memory_ptrs_[index] = {mem_storage_ptr, mem_desc};
+}
+
+std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>
+DNNLMatMulPrimitiveHandler::get_runtime_memory_ptr(size_t index) {
+  return runtime_memory_ptrs_[index];
+}
+
+namespace std {
+template <>
+struct hash<W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey> {
+  size_t operator()(
+      const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
+    return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size) ^
+           hash<int>()(static_cast<int>(val.a_qs)) ^
+           hash<int>()(static_cast<int>(val.b_qs)) ^ hash<bool>()(val.use_azp) ^
+           hash<int>()(static_cast<int>(val.c_type));
+  }
+};
+
+template <>
+struct hash<W8A8MatMulPrimitiveHandler::MSizeCacheKey> {
+  size_t operator()(
+      const W8A8MatMulPrimitiveHandler::MSizeCacheKey& val) const {
+    return hash<dnnl_dim_t>()(val.a_m_size) ^ hash<bool>()(val.use_bias) ^
+           hash<int>()(static_cast<int>(val.bias_type));
+  }
+};
+}  // namespace std
+
+bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
+ const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
+ return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size &&
+ l.a_qs == r.a_qs && l.b_qs == r.b_qs && l.use_azp == r.use_azp &&
+ l.c_type == r.c_type;
+}
+
+bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l,
+ const W8A8MatMulPrimitiveHandler::MSizeCacheKey& r) {
+ return l.use_bias == r.use_bias && l.a_m_size == r.a_m_size &&
+ l.bias_type == r.bias_type;
+}
+
+static std::shared_ptr<W8A8MatMulPrimitiveHandler::MSizeCache>
+get_w8a8_class_primitive_cache(
+    const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
+    int64_t cache_size) {
+  static W8A8MatMulPrimitiveHandler::ClassMatmulCache cache(128);
+  assert(cache_size > 0);
+  return cache.get_or_create(key, [&]() {
+    return std::make_shared<W8A8MatMulPrimitiveHandler::MSizeCache>(cache_size);
+  });
+}
+
+W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
+    : DNNLMatMulPrimitiveHandler(
+          static_cast<const DNNLMatMulPrimitiveHandler::Args&>(args),
+          dnnl::memory::data_type::s8),
+ use_azp_(args.use_a_zero_point),
+ a_qs_(args.a_quantization_strategy),
+ b_qs_(args.b_quantization_strategy),
+ m_size_cache_(nullptr) {
+ assert(a_qs_ != QuantizationStrategy::PER_OUTPUT_CHANNEL);
+ assert(b_qs_ != QuantizationStrategy::PER_TOKEN);
+ if (a_qs_ == QuantizationStrategy::PER_TOKEN) {
+ assert(!use_azp_);
+ };
+ prepack_weight(args.b_ptr,
+ create_primitive_desc(
+ MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
+ .use_bias = false,
+ .bias_type = dnnl::memory::data_type::undef},
+ true)
+ .weights_desc());
+ init_runtime_memory_cache(args);
+}
+
+void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
+ auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0);
+ auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1);
+ a_storage->set_data_handle((void*)args.a_ptr);
+ a_mem_desc->dims[0] = args.a_m_size;
+ c_storage->set_data_handle((void*)args.c_ptr);
+ c_mem_desc->dims[0] = args.a_m_size;
+
+ if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
+ auto&& [a_scale_storage, a_scale_mem_desc] = get_runtime_memory_ptr(2);
+ a_scale_storage->set_data_handle((void*)args.a_scales_ptr);
+ }
+ if (use_azp_) {
+ auto&& [a_zero_point_storage, a_zero_point_mem_desc] =
+ get_runtime_memory_ptr(3);
+ a_zero_point_storage->set_data_handle((void*)args.a_zero_points_ptr);
+ }
+
+ if (args.use_bias) {
+ auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(4);
+ bias_storage->set_data_handle((void*)args.bias_ptr);
+ }
+
+ dnnl::matmul matmul = get_matmul_cache(args);
+ matmul.execute(default_stream(), memory_cache_);
+ default_stream().wait();
+}
+
+dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
+ const MSizeCacheKey& key) {
+ if (m_size_cache_.get() == nullptr) {
+ ClassMatmulCacheKey key = {.b_n_size = b_n_size_,
+ .b_k_size = b_k_size_,
+ .a_qs = a_qs_,
+ .b_qs = b_qs_,
+ .use_azp = use_azp_,
+ .c_type = c_type_};
+ m_size_cache_ = get_w8a8_class_primitive_cache(key, primitive_cache_size_);
+ }
+
+ return m_size_cache_->get_or_create(key, [&]() {
+ dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
+ return dnnl::matmul(desc);
+ });
+}
+
+void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
+ memory_cache_[DNNL_ARG_SRC] = dnnl::memory({{1, b_k_size_},
+ dnnl::memory::data_type::s8,
+ dnnl::memory::format_tag::ab},
+ default_engine(), nullptr);
+ set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get());
+ memory_cache_[DNNL_ARG_DST] =
+ dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab},
+ default_engine(), nullptr);
+ set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());
+
+ // For PER_TOKEN, scales will be applied in outside epilogue
+ if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
+ memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = dnnl::memory(
+ {{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr);
+ set_runtime_memory_ptr(
+ 2, memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC].get());
+ if (use_azp_) {
+ memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = dnnl::memory(
+ {{1}, dnnl::memory::data_type::s32, {1}}, default_engine(), nullptr);
+ set_runtime_memory_ptr(
+ 3, memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC].get());
+ }
+ }
+
+ if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
+ memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
+ dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, default_engine(),
+ (void*)args.b_scales_ptr);
+ } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
+ memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
+ dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
+ default_engine(), (void*)args.b_scales_ptr);
+ }
+
+ memory_cache_[DNNL_ARG_BIAS] =
+ dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
+ default_engine(), nullptr);
+ set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get());
+}
+
+dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
+ const MSizeCacheKey& key, bool first_time) {
+ dnnl::memory::desc a_md({key.a_m_size, b_k_size_},
+ dnnl::memory::data_type::s8,
+ dnnl::memory::format_tag::ab);
+ dnnl::memory::desc b_md;
+ if (first_time) {
+ b_md =
+ dnnl::memory::desc({b_k_size_, b_n_size_}, dnnl::memory::data_type::s8,
+ dnnl::memory::format_tag::any);
+ } else {
+ b_md = b_target_mem_desc_;
+ }
+ dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
+ dnnl::memory::format_tag::ab);
+
+ dnnl::primitive_attr attr;
+ // For PER_TOKEN, scales will be applied in outside epilogue
+ if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
+ attr.set_scales_mask(DNNL_ARG_SRC, 0);
+ if (use_azp_) {
+ attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
+ }
+ }
+
+ if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
+ attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
+ } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
+ attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
+ }
+
+ if (key.use_bias) {
+ // For PER_TOKEN, bias will be applied in epilogue
+ assert(a_qs_ == QuantizationStrategy::PER_TENSOR);
+ dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
+ return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
+ c_md, attr);
+ } else {
+ return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
+ attr);
+ }
+}
diff --git a/csrc/cpu/dnnl_helper.h b/csrc/cpu/dnnl_helper.h
new file mode 100644
index 0000000000000..54ceefced9e98
--- /dev/null
+++ b/csrc/cpu/dnnl_helper.h
@@ -0,0 +1,169 @@
+#ifndef DNNL_HELPER_H
+#define DNNL_HELPER_H
+
+#include
+#include
+
+#include "oneapi/dnnl/dnnl.hpp"
+
+namespace c10 {
+struct BFloat16;
+struct Half;
+} // namespace c10
+
+namespace dnnl {
+namespace impl {
+struct memory_storage_t;
+struct matmul_pd_t;
+struct matmul_desc_t;
+} // namespace impl
+} // namespace dnnl
+struct dnnl_memory_desc;
+
+template <typename KT, typename VT>
+class DNNLPrimitiveCache;
+
+template <typename T>
+struct DNNLType {
+  static constexpr dnnl::memory::data_type type =
+      dnnl::memory::data_type::undef;
+};
+
+template <>
+struct DNNLType<int8_t> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
+};
+
+template <>
+struct DNNLType<int32_t> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
+};
+
+template <>
+struct DNNLType<float> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
+};
+
+template <>
+struct DNNLType<c10::BFloat16> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
+};
+
+template <>
+struct DNNLType<c10::Half> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
+};
+
+template <typename T>
+constexpr inline dnnl::memory::data_type get_dnnl_type() {
+  return DNNLType<std::decay_t<T>>::type;
+}
+
+class DNNLMatMulPrimitiveHandler {
+ public:
+ virtual ~DNNLMatMulPrimitiveHandler() = default;
+
+ protected:
+ struct Args {
+ dnnl_dim_t b_n_size;
+ dnnl_dim_t b_n_stride;
+ dnnl_dim_t b_k_size;
+ dnnl_dim_t b_k_stride;
+ void* b_ptr;
+ dnnl::memory::data_type c_type;
+ size_t primitive_cache_size;
+ };
+
+ protected:
+ DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type);
+
+ void prepack_weight(void* original_b_ptr,
+ dnnl::memory::desc b_target_mem_desc);
+
+ void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr);
+
+ std::pair
+ get_runtime_memory_ptr(size_t index);
+
+ protected:
+ const dnnl_dim_t b_n_size_;
+ const dnnl_dim_t b_n_stride_;
+ const dnnl_dim_t b_k_size_;
+ const dnnl_dim_t b_k_stride_;
+ dnnl::memory::data_type b_type_;
+ dnnl::memory::data_type c_type_;
+  std::unordered_map<int, dnnl::memory> memory_cache_;
+  std::vector<std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>>
+      runtime_memory_ptrs_;
+ dnnl::memory::desc b_target_mem_desc_;
+ int64_t primitive_cache_size_;
+};
+
+class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler {
+ public:
+ enum class QuantizationStrategy { PER_TOKEN, PER_TENSOR, PER_OUTPUT_CHANNEL };
+
+ struct Args : public DNNLMatMulPrimitiveHandler::Args {
+ bool use_a_zero_point;
+ QuantizationStrategy a_quantization_strategy;
+ QuantizationStrategy b_quantization_strategy;
+ float* b_scales_ptr;
+ };
+
+ struct ClassMatmulCacheKey {
+ dnnl_dim_t b_n_size;
+ dnnl_dim_t b_k_size;
+ QuantizationStrategy a_qs;
+ QuantizationStrategy b_qs;
+ bool use_azp;
+ dnnl::memory::data_type c_type;
+
+ friend bool operator==(const ClassMatmulCacheKey& l,
+ const ClassMatmulCacheKey& r);
+ };
+
+ struct MSizeCacheKey {
+ dnnl_dim_t a_m_size;
+ bool use_bias;
+ dnnl::memory::data_type bias_type;
+
+ friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r);
+ };
+
+  using MSizeCache = DNNLPrimitiveCache<MSizeCacheKey, dnnl::matmul>;
+  using ClassMatmulCache =
+      DNNLPrimitiveCache<ClassMatmulCacheKey, std::shared_ptr<MSizeCache>>;
+
+ struct ExecArgs : public MSizeCacheKey {
+ const int8_t* a_ptr;
+ const float* a_scales_ptr;
+ const int32_t* a_zero_points_ptr;
+ const void* bias_ptr;
+ void* c_ptr;
+ };
+
+ public:
+ W8A8MatMulPrimitiveHandler(const Args& args);
+
+ QuantizationStrategy get_input_scale_strategy() const { return a_qs_; }
+
+ bool get_input_use_zero_point() const { return use_azp_; }
+
+ void execute(ExecArgs& args);
+
+ private:
+ dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key,
+ bool first_time);
+
+ void init_runtime_memory_cache(const Args& args);
+
+ dnnl::matmul get_matmul_cache(const MSizeCacheKey& key);
+
+ private:
+ const bool use_azp_;
+ const QuantizationStrategy a_qs_;
+ const QuantizationStrategy b_qs_;
+  std::shared_ptr<MSizeCache> m_size_cache_;
+};
+
+#endif
diff --git a/csrc/cpu/dnnl_helper.hpp b/csrc/cpu/dnnl_helper.hpp
deleted file mode 100644
index 1cb8dc5b25a66..0000000000000
--- a/csrc/cpu/dnnl_helper.hpp
+++ /dev/null
@@ -1,206 +0,0 @@
-#ifndef DNNL_HELPER_HPP
-#define DNNL_HELPER_HPP
-
-#include
-#include
-
-#include "oneapi/dnnl/dnnl.hpp"
-
-namespace {
-template
-struct DNNLType {
- static constexpr dnnl::memory::data_type type =
- dnnl::memory::data_type::undef;
-};
-
-template <>
-struct DNNLType {
- static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
-};
-
-template <>
-struct DNNLType {
- static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
-};
-
-template <>
-struct DNNLType {
- static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
-};
-
-template <>
-struct DNNLType {
- static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
-};
-
-template <>
-struct DNNLType {
- static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
-};
-
-template
-constexpr inline dnnl::memory::data_type get_dnnl_type() {
- return DNNLType>::type;
-}
-}; // namespace
-
-template
-class DNNLPrimitiveHelper {
- public:
- // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias)
- // A: [M, K], row-major
- // B: [K, N], column-major
- // C: [M, N], row-major
- // bias: [N], row-major, optional
- // a_scales: [MS]
- // b_scales: [NS]
- // Note: Due to the limitation of oneDNN
- // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
- // not supported.
-
- template
- static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
- const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
- dnnl_dim_t K, const float* a_scales,
- const float* b_scales, dnnl_dim_t MS,
- dnnl_dim_t NS) {
- auto&& OutputType = get_dnnl_type();
- auto&& BiasType = get_dnnl_type();
-
- dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});
- dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K});
- dnnl::memory::desc c_md({M, N}, OutputType, {N, 1});
-
- dnnl::primitive_attr attr;
- if constexpr (!InputNoScale) {
- if (MS == 1) {
- // per-tensor
- attr.set_scales_mask(DNNL_ARG_SRC, 0);
- } else {
- // per-token
- TORCH_CHECK(false, "per-token quantization is unsupported.");
- }
- }
-
- if (NS == 1) {
- // per-tensor
- attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
- } else {
- // per-channel
- attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
- }
-
- dnnl::matmul::primitive_desc matmul_pd;
-// Create memory descriptors with format_tag::any for the primitive. This
-// enables the matmul primitive to choose memory layouts for an
-// optimized primitive implementation, and these layouts may differ from the
-// ones provided by the user.
-#ifdef __aarch64__
- auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8,
- dnnl::memory::format_tag::any);
- auto mat_weights_md = dnnl::memory::desc(
- {K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any);
- auto mat_dst_md =
- dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any);
- if (bias) {
- dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
- matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md,
- mat_weights_md, bias_md,
- mat_dst_md, attr);
- } else {
- matmul_pd = dnnl::matmul::primitive_desc(
- default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr);
- }
-#else
- if (bias) {
- dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
- matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
- bias_md, c_md, attr);
- } else {
- matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
- c_md, attr);
- }
-#endif
- dnnl::matmul matmul(matmul_pd);
-
- auto& engine = default_engine();
-
- dnnl::memory a_m(a_md, engine, (void*)a);
- dnnl::memory b_m(b_md, engine, (void*)b);
- dnnl::memory c_m(c_md, engine, (void*)c);
- dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine,
- (void*)a_scales);
- dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine,
- (void*)b_scales);
-
- auto& stream = default_stream();
-
- auto mat_src_mem = a_m;
- auto mat_weights_mem = b_m;
- auto mat_dst_mem = c_m;
-#ifdef __aarch64__
- if (matmul_pd.weights_desc() != b_m.get_desc()) {
- mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine);
- dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem);
- }
-#endif
- if constexpr (InputNoScale) {
- if (bias) {
- dnnl::memory::desc bias_md({N}, BiasType, {1});
- dnnl::memory bias_m(bias_md, engine, (void*)bias);
- matmul.execute(
- stream, {
- {DNNL_ARG_SRC, mat_src_mem},
- {DNNL_ARG_WEIGHTS, mat_weights_mem},
- {DNNL_ARG_BIAS, bias_m},
- {DNNL_ARG_DST, mat_dst_mem},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
- });
- } else {
- matmul.execute(
- stream, {
- {DNNL_ARG_SRC, mat_src_mem},
- {DNNL_ARG_WEIGHTS, mat_weights_mem},
- {DNNL_ARG_DST, mat_dst_mem},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
- });
- }
- } else {
- if (bias) {
- dnnl::memory::desc bias_md({N}, BiasType, {1});
- dnnl::memory bias_m(bias_md, engine, (void*)bias);
- matmul.execute(
- stream, {
- {DNNL_ARG_SRC, mat_src_mem},
- {DNNL_ARG_WEIGHTS, mat_weights_mem},
- {DNNL_ARG_BIAS, bias_m},
- {DNNL_ARG_DST, mat_dst_mem},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
- });
- } else {
- matmul.execute(
- stream, {
- {DNNL_ARG_SRC, mat_src_mem},
- {DNNL_ARG_WEIGHTS, mat_weights_mem},
- {DNNL_ARG_DST, mat_dst_mem},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
- });
- }
- }
- stream.wait();
- }
-
- private:
- static dnnl::engine& default_engine() {
- static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
- return engine;
- }
-
- static dnnl::stream& default_stream() {
- static dnnl::stream stream(default_engine());
- return stream;
- }
-};
-#endif
diff --git a/csrc/cpu/dnnl_kernels.cpp b/csrc/cpu/dnnl_kernels.cpp
new file mode 100644
index 0000000000000..acc3b9ecde143
--- /dev/null
+++ b/csrc/cpu/dnnl_kernels.cpp
@@ -0,0 +1,494 @@
+#include "cpu_types.hpp"
+#include "dnnl_helper.h"
+
+namespace {
+template <typename scalar_t>
+struct KernelVecType {
+  using load_vec_type = void;
+  using cvt_vec_type = void;
+};
+
+template <>
+struct KernelVecType<float> {
+  using load_vec_type = vec_op::FP32Vec16;
+  using cvt_vec_type = vec_op::FP32Vec16;
+};
+
+#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
+template <>
+struct KernelVecType<c10::BFloat16> {
+  using load_vec_type = vec_op::BF16Vec16;
+  using cvt_vec_type = vec_op::FP32Vec16;
+};
+#endif
+
+template <>
+struct KernelVecType<c10::Half> {
+#if defined(__powerpc64__) || defined(__s390x__)
+  // Power architecture-specific vector type
+  using load_vec_type = vec_op::FP32Vec16;
+#else
+  // Fallback for other architectures
+  using load_vec_type = vec_op::FP16Vec16;
+#endif
+  using cvt_vec_type = vec_op::FP32Vec16;
+};
+
+template <bool AZP, typename scalar_t>
+void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
+                                   const float* scale, const int32_t* azp,
+                                   const int64_t num_tokens,
+                                   const int64_t input_stride,
+                                   const int64_t hidden_size) {
+  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
+  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
+  constexpr int64_t vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+
+  constexpr float i8_min =
+      static_cast<float>(std::numeric_limits<int8_t>::min());
+  constexpr float i8_max =
+      static_cast<float>(std::numeric_limits<int8_t>::max());
+  const cvt_vec_t inv_scale(1.0 / *scale);
+ const cvt_vec_t i8_min_vec(i8_min);
+ const cvt_vec_t i8_max_vec(i8_max);
+
+ cvt_vec_t zp_vec;
+ if constexpr (AZP) {
+ zp_vec = cvt_vec_t(static_cast(*azp));
+ }
+
+#pragma omp parallel for
+ for (int64_t i = 0; i < num_tokens; ++i) {
+ int64_t j = 0;
+ const scalar_t* input_ptr = input + i * input_stride;
+ int8_t* output_ptr = output + i * hidden_size;
+ for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+ elems_fp32 = elems_fp32 * inv_scale;
+
+ if constexpr (AZP) {
+ elems_fp32 = elems_fp32 + zp_vec;
+ }
+
+ elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+ vec_op::INT8Vec16 elems_int8(elems_fp32);
+ elems_int8.save(output_ptr + j);
+ }
+
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+ elems_fp32 = elems_fp32 * inv_scale;
+
+ if constexpr (AZP) {
+ elems_fp32 = elems_fp32 + zp_vec;
+ }
+
+ elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+ vec_op::INT8Vec16 elems_int8(elems_fp32);
+ elems_int8.save(output_ptr + j, hidden_size - j);
+ }
+}
+
+template <bool AZP, typename scalar_t>
+void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
+                                    float* scale, int32_t* azp,
+                                    const int64_t num_tokens,
+                                    const int64_t input_stride,
+                                    const int64_t hidden_size) {
+  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
+  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
+  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+
+  constexpr float i8_min =
+      static_cast<float>(std::numeric_limits<int8_t>::min());
+  constexpr float i8_max =
+      static_cast<float>(std::numeric_limits<int8_t>::max());
+  const cvt_vec_t i8_min_vec(i8_min);
+  const cvt_vec_t i8_max_vec(i8_max);
+
+#pragma omp parallel for
+ for (int64_t i = 0; i < num_tokens; ++i) {
+    cvt_vec_t max_value(std::numeric_limits<float>::lowest());
+    cvt_vec_t min_value(std::numeric_limits<float>::max());
+ {
+ int64_t j = 0;
+ const scalar_t* input_ptr = input + i * input_stride;
+ for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+ if constexpr (AZP) {
+ max_value = max_value.max(elems_fp32);
+ min_value = min_value.min(elems_fp32);
+ } else {
+ max_value = max_value.max(elems_fp32.abs());
+ }
+ }
+
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+
+ if (j + vec_elem_num == hidden_size) {
+ if constexpr (AZP) {
+ max_value = max_value.max(elems_fp32);
+ min_value = min_value.min(elems_fp32);
+ } else {
+ max_value = max_value.max(elems_fp32.abs());
+ }
+ } else {
+ if constexpr (AZP) {
+ max_value = max_value.max(elems_fp32, hidden_size - j);
+ min_value = min_value.min(elems_fp32, hidden_size - j);
+ } else {
+ max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
+ }
+ }
+ }
+
+ float scale_val, azp_val;
+ if constexpr (AZP) {
+ float max_scalar = max_value.reduce_max();
+ float min_scalar = min_value.reduce_min();
+ scale_val = (max_scalar - min_scalar) / 255.0f;
+ azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
+ azp[i] = azp_val;
+ scale[i] = scale_val;
+ } else {
+ scale_val = max_value.reduce_max() / 127.0f;
+ scale[i] = scale_val;
+ }
+
+ const cvt_vec_t inv_scale(1.0 / scale_val);
+ const cvt_vec_t azp_vec(azp_val);
+
+ {
+ int64_t j = 0;
+ const scalar_t* input_ptr = input + i * input_stride;
+ int8_t* output_ptr = output + i * hidden_size;
+ for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+ elems_fp32 = (elems_fp32 * inv_scale);
+
+ if constexpr (AZP) {
+ elems_fp32 = elems_fp32 + azp_vec;
+ }
+ elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+ vec_op::INT8Vec16 elems_int8(elems_fp32);
+ elems_int8.save(output_ptr + j);
+ }
+
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+ elems_fp32 = (elems_fp32 * inv_scale);
+
+ if constexpr (AZP) {
+ elems_fp32 = elems_fp32 + azp_vec;
+ }
+ elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+ vec_op::INT8Vec16 elems_int8(elems_fp32);
+ elems_int8.save(output_ptr + j, hidden_size - j);
+ }
+ }
+}
+
+template <bool AZP, bool Bias, typename scalar_t>
+void dynamic_quant_epilogue(const float* input, scalar_t* output,
+                            const float* a_scale, const int32_t* azp,
+                            const float* azp_adj, const scalar_t* bias,
+                            const int64_t num_tokens,
+                            const int64_t hidden_size) {
+  CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
+  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
+  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
+  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+
+ const int64_t thread_num = omp_get_max_threads();
+ if (num_tokens > thread_num) {
+#pragma omp parallel for
+ for (int64_t i = 0; i < num_tokens; ++i) {
+ const float* input_ptr = input + i * hidden_size;
+ scalar_t* output_ptr = output + i * hidden_size;
+ int64_t j = 0;
+ cvt_vec_t token_scale_vec(a_scale[i]);
+ cvt_vec_t token_zp_scale_vec;
+ if constexpr (AZP) {
+ float zp_scale_val = a_scale[i] * static_cast(azp[i]);
+ token_zp_scale_vec = cvt_vec_t(zp_scale_val);
+ }
+ for (; j < hidden_size - vec_elem_num; ++j) {
+ cvt_vec_t elems_fp32(input_ptr + j);
+ elems_fp32 = elems_fp32 * token_scale_vec;
+ if constexpr (AZP) {
+ cvt_vec_t azp_adj_fp32(azp_adj + j);
+ elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
+ }
+ if constexpr (Bias) {
+ load_vec_t bias_vec(bias + j);
+ cvt_vec_t bias_vec_fp32(bias_vec);
+ elems_fp32 = elems_fp32 + bias_vec_fp32;
+ }
+ load_vec_t elems_out(elems_fp32);
+ elems_out.save(output_ptr + j);
+ }
+ cvt_vec_t elems_fp32(input_ptr + j);
+ elems_fp32 = elems_fp32 * token_scale_vec;
+ if constexpr (AZP) {
+ cvt_vec_t azp_adj_fp32(azp_adj + j);
+ elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
+ }
+ if constexpr (Bias) {
+ load_vec_t bias_vec(bias + j);
+ cvt_vec_t bias_vec_fp32(bias_vec);
+ elems_fp32 = elems_fp32 + bias_vec_fp32;
+ }
+ load_vec_t elems_out(elems_fp32);
+ elems_out.save(output_ptr + j, hidden_size - j);
+ }
+ } else {
+ const int64_t vec_iteration =
+ (hidden_size + vec_elem_num - 1) / vec_elem_num;
+ const int64_t vec_iteration_per_thread =
+ (vec_iteration + thread_num - 1) / thread_num;
+ const int64_t elem_num_per_thread = vec_iteration_per_thread * vec_elem_num;
+#pragma omp parallel for schedule(static, 1)
+ for (int64_t i = 0; i < thread_num; ++i) {
+ const int64_t start = elem_num_per_thread * i;
+ const int64_t end = std::min(hidden_size, elem_num_per_thread + start);
+ for (int64_t j = 0; j < num_tokens; ++j) {
+ cvt_vec_t token_scale_vec(a_scale[j]);
+ cvt_vec_t token_zp_scale_vec;
+ if constexpr (AZP) {
+ float zp_scale_val = a_scale[j] * static_cast(azp[j]);
+ token_zp_scale_vec = cvt_vec_t(zp_scale_val);
+ }
+ int64_t k = start;
+ const float* input_ptr = input + j * hidden_size;
+ scalar_t* output_ptr = output + j * hidden_size;
+ for (; k < end - vec_elem_num; k += vec_elem_num) {
+ cvt_vec_t elems_fp32(input_ptr + k);
+ elems_fp32 = elems_fp32 * token_scale_vec;
+ if constexpr (AZP) {
+ cvt_vec_t azp_adj_fp32(azp_adj + k);
+ elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
+ }
+ if constexpr (Bias) {
+ load_vec_t bias_vec(bias + k);
+ cvt_vec_t bias_vec_fp32(bias_vec);
+ elems_fp32 = elems_fp32 + bias_vec_fp32;
+ }
+ load_vec_t elems_out(elems_fp32);
+ elems_out.save(output_ptr + k);
+ }
+ if (k < end) {
+ cvt_vec_t elems_fp32(input_ptr + k);
+ elems_fp32 = elems_fp32 * token_scale_vec;
+ if constexpr (AZP) {
+ cvt_vec_t azp_adj_fp32(azp_adj + k);
+ elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
+ }
+ if constexpr (Bias) {
+ load_vec_t bias_vec(bias + k);
+ cvt_vec_t bias_vec_fp32(bias_vec);
+ elems_fp32 = elems_fp32 + bias_vec_fp32;
+ }
+ load_vec_t elems_out(elems_fp32);
+ elems_out.save(output_ptr + k, end - k);
+ }
+ }
+ }
+ }
+}
+} // namespace
+
+int64_t create_onednn_scaled_mm_handler(
+ const torch::Tensor& b, // [IC, OC], column-major
+ const torch::Tensor& b_scales, // [1] or [OC]
+ at::ScalarType output_type, bool dynamic_act_quant, bool use_azp,
+ int64_t primitive_cache_size) {
+ TORCH_CHECK(b.dim() == 2);
+ TORCH_CHECK(b.stride(0) == 1); // Column-major
+ TORCH_CHECK(b_scales.is_contiguous());
+
+ W8A8MatMulPrimitiveHandler::Args args;
+ args.primitive_cache_size = primitive_cache_size;
+
+ if (b_scales.numel() == 1) {
+ args.b_quantization_strategy =
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR;
+ } else {
+ TORCH_CHECK_EQ(b_scales.numel(), b.size(1));
+ args.b_quantization_strategy =
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_OUTPUT_CHANNEL;
+ }
+ args.b_scales_ptr = b_scales.data_ptr();
+ args.b_k_size = b.size(0);
+ args.b_k_stride = b.stride(0);
+ args.b_n_size = b.size(1);
+ args.b_n_stride = b.stride(1);
+ args.b_ptr = b.data_ptr();
+
+ if (dynamic_act_quant) {
+ // dynamic per-token, bias, A scales and A zps will be applied in outside.
+ args.a_quantization_strategy =
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN;
+ args.use_a_zero_point = false;
+ } else {
+ // static per-tensor
+ args.a_quantization_strategy =
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR;
+ args.use_a_zero_point = use_azp;
+ }
+
+ VLLM_DISPATCH_FLOATING_TYPES(output_type, "create_onednn_scaled_mm_handler",
+ [&] {
+ if (dynamic_act_quant) {
+ args.c_type = get_dnnl_type();
+ } else {
+ args.c_type = get_dnnl_type();
+ }
+ });
+
+ return reinterpret_cast(new W8A8MatMulPrimitiveHandler(args));
+}
+
+void onednn_scaled_mm(
+ torch::Tensor& c, // [M, OC], row-major
+ const torch::Tensor& a, // [M, IC], row-major
+ const torch::Tensor& a_scales, // [M] or [1]
+ const std::optional& azp, // [M] or [1]
+ const std::optional& azp_adj, // [M] or [1]
+ const std::optional& bias, // [N]
+ int64_t handler) {
+ CPU_KERNEL_GUARD_IN(onednn_scaled_mm)
+ TORCH_CHECK(a.dim() == 2);
+ TORCH_CHECK(a.is_contiguous());
+ TORCH_CHECK(c.is_contiguous());
+ W8A8MatMulPrimitiveHandler* ptr =
+ reinterpret_cast(handler);
+ const int32_t* azp_ptr = nullptr;
+ if (azp.has_value()) {
+ azp_ptr = azp->data_ptr();
+ }
+ if (ptr->get_input_scale_strategy() ==
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) {
+ TORCH_CHECK_EQ(a_scales.numel(), 1);
+ }
+
+ W8A8MatMulPrimitiveHandler::ExecArgs exec_args;
+ exec_args.a_ptr = a.data_ptr();
+ exec_args.a_m_size = a.size(0);
+ exec_args.bias_ptr = nullptr;
+ exec_args.use_bias = false;
+ exec_args.a_scales_ptr = nullptr;
+ exec_args.a_zero_points_ptr = nullptr;
+
+ VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "onednn_scaled_mm", [&] {
+ if (ptr->get_input_scale_strategy() ==
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) {
+ if (bias.has_value()) {
+ exec_args.bias_ptr = bias->data_ptr();
+ exec_args.bias_type = get_dnnl_type();
+ exec_args.use_bias = true;
+ }
+ exec_args.a_scales_ptr = a_scales.data_ptr();
+ exec_args.a_zero_points_ptr = azp_ptr;
+ exec_args.c_ptr = c.data_ptr();
+ ptr->execute(exec_args);
+ } else if (ptr->get_input_scale_strategy() ==
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN) {
+ torch::Tensor tmp_fp32_out =
+ torch::empty_like(c, ::at::ScalarType::Float);
+ exec_args.c_ptr = tmp_fp32_out.data_ptr();
+ ptr->execute(exec_args);
+ if (bias.has_value()) {
+ if (azp.has_value()) {
+ dynamic_quant_epilogue(
+ tmp_fp32_out.data_ptr(), c.data_ptr(),
+ a_scales.data_ptr(), azp_ptr, azp_adj->data_ptr(),
+ bias->data_ptr(), c.size(0), c.size(1));
+ } else {
+ dynamic_quant_epilogue(
+ tmp_fp32_out.data_ptr(), c.data_ptr(),
+ a_scales.data_ptr(), azp_ptr, nullptr,
+ bias->data_ptr(), c.size(0), c.size(1));
+ }
+ } else {
+ if (azp.has_value()) {
+ dynamic_quant_epilogue(
+ tmp_fp32_out.data_ptr(), c.data_ptr(),
+ a_scales.data_ptr(), azp_ptr, azp_adj->data_ptr(),
+ (scalar_t*)nullptr, c.size(0), c.size(1));
+ } else {
+ dynamic_quant_epilogue(
+ tmp_fp32_out.data_ptr(), c.data_ptr(),
+ a_scales.data_ptr(), azp_ptr, nullptr, (scalar_t*)nullptr,
+ c.size(0), c.size(1));
+ }
+ }
+ } else {
+ TORCH_CHECK(false, "invalid act quant type.");
+ }
+ });
+}
+
+// static-per-tensor quantization.
+void static_scaled_int8_quant(
+ torch::Tensor& out, // [batch, hidden_size]
+ const torch::Tensor& input, // [batch, hidden_size]
+ const torch::Tensor& scale, std::optional const& azp) {
+ CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
+ TORCH_CHECK(out.is_contiguous());
+ TORCH_CHECK_EQ(input.dim(), 2);
+ TORCH_CHECK_EQ(input.stride(1), 1);
+ TORCH_CHECK(scale.numel() == 1);
+ TORCH_CHECK(!azp.has_value() || azp->numel() == 1);
+
+ const int64_t stride = input.stride(0);
+ const int64_t hidden_size = input.size(1);
+ const int64_t num_tokens = input.size(0);
+ VLLM_DISPATCH_FLOATING_TYPES(
+ input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
+ if (azp.has_value()) {
+ static_scaled_int8_quant_impl(
+ input.data_ptr(), out.data_ptr(),
+ scale.data_ptr(), azp->data_ptr(), num_tokens,
+ stride, hidden_size);
+ } else {
+ static_scaled_int8_quant_impl(input.data_ptr(),
+ out.data_ptr(),
+ scale.data_ptr(), nullptr,
+ num_tokens, stride, hidden_size);
+ }
+ });
+}
+
+// dynamic-per-token quantization.
+void dynamic_scaled_int8_quant(
+ torch::Tensor& out, // [batch, hidden_size]
+ const torch::Tensor& input, // [batch, hidden_size]
+ torch::Tensor& scale, // [batch, 1]
+ std::optional const& azp) {
+ CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
+ TORCH_CHECK(out.is_contiguous());
+ TORCH_CHECK_EQ(input.dim(), 2);
+ TORCH_CHECK_EQ(input.stride(1), 1);
+
+ const int64_t hidden_size = input.size(1);
+ const int64_t num_tokens = input.size(0);
+ const int64_t stride = input.stride(0);
+ VLLM_DISPATCH_FLOATING_TYPES(
+ input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
+ if (azp.has_value()) {
+ dynamic_scaled_int8_quant_impl(
+ input.data_ptr(), out.data_ptr(),
+ scale.data_ptr(), azp->data_ptr(), num_tokens,
+ stride, hidden_size);
+ } else {
+ dynamic_scaled_int8_quant_impl(
+ input.data_ptr(), out.data_ptr(),
+ scale.data_ptr(), nullptr, num_tokens, stride,
+ hidden_size);
+ }
+ });
+}
diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp
deleted file mode 100644
index 6e120b8d20a7e..0000000000000
--- a/csrc/cpu/quant.cpp
+++ /dev/null
@@ -1,951 +0,0 @@
-#include "cpu_types.hpp"
-#include "dnnl_helper.hpp"
-
-namespace {
-template
-struct KernelVecType {
- using load_vec_type = void;
- using azp_adj_load_vec_type = void;
- using cvt_vec_type = void;
-};
-
-template <>
-struct KernelVecType {
- using load_vec_type = vec_op::FP32Vec16;
- using azp_adj_load_vec_type = vec_op::INT32Vec16;
- using cvt_vec_type = vec_op::FP32Vec16;
-};
-
-#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
-template <>
-struct KernelVecType {
- using load_vec_type = vec_op::BF16Vec16;
- using azp_adj_load_vec_type = vec_op::INT32Vec16;
- using cvt_vec_type = vec_op::FP32Vec16;
-};
-#endif
-
-template <>
-struct KernelVecType {
-#if defined(__powerpc64__) || defined(__s390x__)
- // Power architecture-specific vector type
- using load_vec_type = vec_op::FP32Vec16;
-#else
- // Fallback for other architectures
- using load_vec_type = vec_op::FP16Vec16;
-#endif
- using azp_adj_load_vec_type = vec_op::INT32Vec16;
- using cvt_vec_type = vec_op::FP32Vec16;
-};
-
-#if defined(__AVX512F__) || defined(__aarch64__)
-template
-void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- const float* scale, const int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- using load_vec_t = typename KernelVecType::load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- constexpr float i8_min =
- static_cast(std::numeric_limits::min());
- constexpr float i8_max =
- static_cast(std::numeric_limits::max());
- const cvt_vec_t inv_scale(1.0 / *scale);
- const cvt_vec_t i8_min_vec(i8_min);
- const cvt_vec_t i8_max_vec(i8_max);
-
- cvt_vec_t zp_vec;
- if constexpr (AZP) {
- zp_vec = cvt_vec_t(static_cast(*azp));
- }
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = elems_fp32 * inv_scale;
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + zp_vec;
- }
-
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j);
- }
-
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = elems_fp32 * inv_scale;
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + zp_vec;
- }
-
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-
-template
-void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- float* scale, int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- using load_vec_t = typename KernelVecType::load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- constexpr float i8_min =
- static_cast(std::numeric_limits::min());
- constexpr float i8_max =
- static_cast(std::numeric_limits::max());
- const cvt_vec_t i8_min_vec(i8_min);
- const cvt_vec_t i8_max_vec(i8_max);
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- cvt_vec_t max_value(std::numeric_limits::lowest());
- cvt_vec_t min_value(std::numeric_limits::max());
- {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32);
- min_value = min_value.min(elems_fp32);
- } else {
- max_value = max_value.max(elems_fp32.abs());
- }
- }
-
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
-
- if (j + vec_elem_num == hidden_size) {
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32);
- min_value = min_value.min(elems_fp32);
- } else {
- max_value = max_value.max(elems_fp32.abs());
- }
- } else {
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32, hidden_size - j);
- min_value = min_value.min(elems_fp32, hidden_size - j);
- } else {
- max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
- }
- }
- }
-
- float scale_val, azp_val;
- if constexpr (AZP) {
- float max_scalar = max_value.reduce_max();
- float min_scalar = min_value.reduce_min();
- scale_val = (max_scalar - min_scalar) / 255.0f;
- azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
- azp[i] = static_cast(azp_val);
- scale[i] = scale_val;
- } else {
- scale_val = max_value.reduce_max() / 127.0f;
- scale[i] = scale_val;
- }
-
- const cvt_vec_t inv_scale(1.0 / scale_val);
- const cvt_vec_t azp_vec(azp_val);
-
- {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = (elems_fp32 * inv_scale);
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + azp_vec;
- }
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j);
- }
-
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = (elems_fp32 * inv_scale);
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + azp_vec;
- }
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j, hidden_size - j);
- }
- }
-}
-
-template
-void static_quant_epilogue(const float* input, scalar_t* output,
- const float a_scale, const float* b_scale,
- const int32_t* azp_with_adj, const int num_tokens,
- const int hidden_size) {
- CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
- using load_vec_t = typename KernelVecType::load_vec_type;
- using azp_adj_load_vec_t =
- typename KernelVecType::azp_adj_load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- cvt_vec_t a_scale_vec(a_scale);
- cvt_vec_t b_scale_vec(*b_scale);
- cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;
-
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
-
- if constexpr (PerChannel) {
- b_scale_vec = cvt_vec_t(b_scale + j);
- scale_vec = b_scale_vec * a_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j);
- }
-
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
-
- if constexpr (PerChannel) {
- b_scale_vec = cvt_vec_t(b_scale + j);
- scale_vec = b_scale_vec * a_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-
-template
-void dynamic_quant_epilogue(const float* input, scalar_t* output,
- const float* a_scale, const float* b_scale,
- const int32_t* azp, const int32_t* azp_adj,
- const scalar_t* bias, const int num_tokens,
- const int hidden_size) {
- CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
- using load_vec_t = typename KernelVecType::load_vec_type;
- using azp_adj_load_vec_t =
- typename KernelVecType::azp_adj_load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- int j = 0;
- cvt_vec_t token_scale_vec(a_scale[i]);
- cvt_vec_t token_zp_scale_vec;
- if constexpr (AZP) {
- float zp_scale_val = a_scale[i] * static_cast(azp[i]);
- if constexpr (!PerChannel) {
- zp_scale_val *= *b_scale;
- }
- token_zp_scale_vec = cvt_vec_t(zp_scale_val);
- }
-
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- elems_fp32 = elems_fp32 * token_scale_vec;
-
- if constexpr (AZP) {
- azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
- azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
-
- if constexpr (PerChannel) {
- cvt_vec_t b_scale_vec(b_scale + j);
- azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - azp_adj_fp32;
- }
-
- if constexpr (Bias) {
- load_vec_t bias_vec(bias + j);
- cvt_vec_t bias_vec_fp32(bias_vec);
- elems_fp32 = elems_fp32 + bias_vec_fp32;
- }
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j);
- }
-
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- elems_fp32 = elems_fp32 * token_scale_vec;
-
- if constexpr (AZP) {
- azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
- azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
-
- if constexpr (PerChannel) {
- cvt_vec_t b_scale_vec(b_scale + j);
- azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - azp_adj_fp32;
- }
-
- if constexpr (Bias) {
- load_vec_t bias_vec(bias + j);
- cvt_vec_t bias_vec_fp32(bias_vec);
- elems_fp32 = elems_fp32 + bias_vec_fp32;
- }
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-#elif defined(__powerpc64__)
-template
-void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- const float* scale, const int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- using load_vec_t = typename KernelVecType::load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- constexpr float i8_min =
- static_cast(std::numeric_limits::min());
- constexpr float i8_max =
- static_cast(std::numeric_limits::max());
-
- const cvt_vec_t inv_scale(1.0 / *scale);
- const cvt_vec_t i8_min_vec(i8_min);
- const cvt_vec_t i8_max_vec(i8_max);
-
- cvt_vec_t zp_vec;
- if constexpr (AZP) {
- zp_vec = cvt_vec_t(static_cast(*azp));
- }
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = elems_fp32 * inv_scale;
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + zp_vec;
- }
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j);
- }
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = elems_fp32 * inv_scale;
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + zp_vec;
- }
-
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-template
-void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- float* scale, int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- using load_vec_t = typename KernelVecType::load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- constexpr float i8_min =
- static_cast(std::numeric_limits::min());
- constexpr float i8_max =
- static_cast(std::numeric_limits::max());
- const cvt_vec_t i8_min_vec(i8_min);
- const cvt_vec_t i8_max_vec(i8_max);
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- cvt_vec_t max_value(std::numeric_limits::lowest());
- cvt_vec_t min_value(std::numeric_limits::max());
- {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32);
- min_value = min_value.min(elems_fp32);
- } else {
- max_value = max_value.max(elems_fp32.abs());
- }
- }
-
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
-
- if (j + vec_elem_num == hidden_size) {
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32);
- min_value = min_value.min(elems_fp32);
- } else {
- max_value = max_value.max(elems_fp32.abs());
- }
- } else {
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32, hidden_size - j);
- min_value = min_value.min(elems_fp32, hidden_size - j);
- } else {
- max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
- }
- }
- }
-
- float scale_val, azp_val;
- if constexpr (AZP) {
- float max_scalar = max_value.reduce_max();
- float min_scalar = min_value.reduce_min();
- scale_val = (max_scalar - min_scalar) / 255.0f;
- azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
- azp[i] = static_cast(azp_val);
- scale[i] = scale_val;
- } else {
- scale_val = max_value.reduce_max() / 127.0f;
- scale[i] = scale_val;
- }
-
- const cvt_vec_t inv_scale(1.0 / scale_val);
- const cvt_vec_t azp_vec(azp_val);
-
- {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = (elems_fp32 * inv_scale);
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + azp_vec;
- }
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j);
- }
-
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = (elems_fp32 * inv_scale);
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + azp_vec;
- }
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j, hidden_size - j);
- }
- }
-}
-template
-void static_quant_epilogue(const float* input, scalar_t* output,
- const float a_scale, const float* b_scale,
- const int32_t* azp_with_adj, const int num_tokens,
- const int hidden_size) {
- CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
- using load_vec_t = typename KernelVecType::load_vec_type;
- using azp_adj_load_vec_t =
- typename KernelVecType::azp_adj_load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- cvt_vec_t a_scale_vec(a_scale);
- cvt_vec_t b_scale_vec(*b_scale);
- cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;
-
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
-
- if constexpr (PerChannel) {
- b_scale_vec = cvt_vec_t(b_scale + j);
- scale_vec = b_scale_vec * a_scale_vec;
- }
- elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j);
- }
-
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
-
- if constexpr (PerChannel) {
- b_scale_vec = cvt_vec_t(b_scale + j);
- scale_vec = b_scale_vec * a_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-template
-void dynamic_quant_epilogue(const float* input, scalar_t* output,
- const float* a_scale, const float* b_scale,
- const int32_t* azp, const int32_t* azp_adj,
- const scalar_t* bias, const int num_tokens,
- const int hidden_size) {
- CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
- using load_vec_t = typename KernelVecType::load_vec_type;
- using azp_adj_load_vec_t =
- typename KernelVecType::azp_adj_load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- int j = 0;
- cvt_vec_t token_scale_vec(a_scale[i]);
- cvt_vec_t token_zp_scale_vec;
- if constexpr (AZP) {
- float zp_scale_val = a_scale[i] * static_cast(azp[i]);
- if constexpr (!PerChannel) {
- zp_scale_val *= *b_scale;
- }
- token_zp_scale_vec = cvt_vec_t(zp_scale_val);
- }
-
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- elems_fp32 = elems_fp32 * token_scale_vec;
-
- if constexpr (AZP) {
- azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
- azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
-
- if constexpr (PerChannel) {
- cvt_vec_t b_scale_vec(b_scale + j);
- azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - azp_adj_fp32;
- }
-
- if constexpr (Bias) {
- load_vec_t bias_vec(bias + j);
- cvt_vec_t bias_vec_fp32(bias_vec);
- elems_fp32 = elems_fp32 + bias_vec_fp32;
- }
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j);
- }
-
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- elems_fp32 = elems_fp32 * token_scale_vec;
-
- if constexpr (AZP) {
- azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
- azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
-
- if constexpr (PerChannel) {
- cvt_vec_t b_scale_vec(b_scale + j);
- azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - azp_adj_fp32;
- }
-
- if constexpr (Bias) {
- load_vec_t bias_vec(bias + j);
- cvt_vec_t bias_vec_fp32(bias_vec);
- elems_fp32 = elems_fp32 + bias_vec_fp32;
- }
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-#else
-template
-void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- const float* scale, const int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- TORCH_CHECK(false,
- "static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 "
- "support.")
-}
-
-template
-void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- float* scale, int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- TORCH_CHECK(false,
- "dynamic_scaled_int8_quant_impl requires "
- "AVX512/powerpc64/AArch64 support.")
-}
-
-template
-void static_quant_epilogue(const float* input, scalar_t* output,
- const float a_scale, const float* b_scale,
- const int32_t* azp_with_adj, const int num_tokens,
- const int hidden_size) {
- TORCH_CHECK(
- false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
-}
-
-template
-void dynamic_quant_epilogue(const float* input, scalar_t* output,
- const float* a_scale, const float* b_scale,
- const int32_t* azp, const int32_t* azp_with_adj,
- const scalar_t* bias, const int num_tokens,
- const int hidden_size) {
- TORCH_CHECK(
- false,
- "dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
-}
-#endif
-} // namespace
-
-void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
- const torch::Tensor& a, // [M, IC], row-major
- const torch::Tensor& b, // [IC, OC], column-major
- const torch::Tensor& a_scales, // [1] or [M]
- const torch::Tensor& b_scales, // [1] or [OC]
- const std::optional& bias // [OC]
-) {
- CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
- // Checks for conformality
- TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
- "int8_scaled_mm only supports INT8 inputs.")
- TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
- TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
- b.size(1) == c.size(1));
- TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
- TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
-
- // Check for strides and alignment
- TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
- TORCH_CHECK(b.stride(0) == 1); // Column-major
- TORCH_CHECK(c.stride(0) % 16 == 0 &&
- b.stride(1) % 16 == 0); // 16 Byte Alignment
- TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-
- if (bias) {
- TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
- bias->dim() == 1);
- }
-
- VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] {
- if (a_scales.numel() != 1) {
- // per-token
- // Note: oneDNN doesn't support per-token activation quantization
- // Ideally we want to fuse the GEMM and the scale procedure with oneDNN
- // JIT, the intermediate data is cached in registers or L1. But for now
- // the oneDNN GEMM code generation only supports two quantization
- // patterns: per-tensor or per-output-channel of weight.
- // So we have to apply the per-token scale with a 'epilogue'. In C=s_a *
- // s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN
- // GEMM, then the per-token scale (and bias) is applied with the epilogue
- // C=s_a * C_inter + bias.
- torch::Tensor tmp_fp32_out =
- torch::empty_like(c, ::at::ScalarType::Float);
- // Compute C_inter=s_b * (A@B)
- DNNLPrimitiveHelper::gemm_s8s8_jit(
- a.data_ptr(), b.data_ptr(),
- tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1),
- a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel());
- if (bias.has_value()) {
- // Compute C=s_a * C_inter + bias
- dynamic_quant_epilogue(
- tmp_fp32_out.data_ptr(), c.data_ptr(),
- a_scales.data_ptr(), nullptr, nullptr, nullptr,
- bias->data_ptr(), c.size(0), c.size(1));
- } else {
- // Compute C=s_a * C_inter
- dynamic_quant_epilogue(
- tmp_fp32_out.data_ptr(), c.data_ptr(),
- a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr,
- c.size(0), c.size(1));
- }
- } else {
- // per-tensor
- if (bias.has_value()) {
- // Compute C=s_a * s_b * (A@B) + bias
- DNNLPrimitiveHelper::gemm_s8s8_jit(
- a.data_ptr(), b.data_ptr(), c.data_ptr(),
- bias->data_ptr(), a.size(0), b.size(1), a.size(1),
- a_scales.data_ptr(), b_scales.data_ptr(),
- a_scales.numel(), b_scales.numel());
- } else {
- // Compute C=s_a * s_b * (A@B)
- DNNLPrimitiveHelper::gemm_s8s8_jit(
- a.data_ptr(), b.data_ptr(), c.data_ptr(),
- nullptr, a.size(0), b.size(1), a.size(1),
- a_scales.data_ptr(), b_scales.data_ptr(),
- a_scales.numel(), b_scales.numel());
- }
- }
- });
-}
-
-void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major
- const torch::Tensor& a, // [M, IC], row-major
- const torch::Tensor& b, // [IC, OC], column-major
- const torch::Tensor& a_scales, // [1] or [M]
- const torch::Tensor& b_scales, // [1] or [OC]
- const torch::Tensor& azp_adj, // [OC]
- const std::optional& azp, // [1] or [M]
- const std::optional& bias // [OC]
-) {
- CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp)
- // Checks for conformality
- TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
- "int8_scaled_mm_azp only supports INT8 inputs.")
- TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
- TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
- b.size(1) == c.size(1));
- TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
- TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
-
- // Check for strides and alignment
- TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
- TORCH_CHECK(b.stride(0) == 1); // Column-major
- TORCH_CHECK(c.stride(0) % 16 == 0 &&
- b.stride(1) % 16 == 0); // 16 Byte Alignment
- TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-
- if (bias) {
- TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
- }
- if (azp) {
- TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
- }
- TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());
-
- // azp & bias types
- TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
- TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
- TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
- "currently bias dtype must match output dtype ", c.dtype());
-
- VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] {
- torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float);
- if (a_scales.numel() != 1) {
- // per-token
- // Note: oneDNN doesn't support per-token activation quantization
- // Compute C_inter=s_b * (A@B)
- DNNLPrimitiveHelper::gemm_s8s8_jit(
- a.data_ptr(), b.data_ptr(),
- tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1),
- a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel());
- if (bias.has_value()) {
- // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias
- if (b_scales.numel() != 1) {
- // Per-Channel
- dynamic_quant_epilogue(
- tmp_fp32_out.data_ptr(), c.data_ptr