diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh
new file mode 100644
index 0000000000000..f187d1f181724
--- /dev/null
+++ b/.buildkite/run-cpu-test.sh
@@ -0,0 +1,14 @@
+# This script builds the CPU docker image and runs the offline inference inside the container.
+# It serves as a sanity check for compilation and basic model usage.
+set -ex
+
+# Try building the docker image
+docker build -t cpu-test -f Dockerfile.cpu .
+
+# Setup cleanup
+remove_docker_container() { docker rm -f cpu-test || true; }
+trap remove_docker_container EXIT
+remove_docker_container
+
+# Run the image and launch offline inference
+docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index ee384c27e1d0c..27e44463a30a6 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -34,7 +34,10 @@ steps:
command: pytest -v -s engine tokenization test_sequence.py test_config.py
- label: Entrypoints Test
- command: pytest -v -s entrypoints
+ commands:
+    # these tests have to be separated, because each one will allocate all possible GPU memory
+ - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
+ - pytest -v -s entrypoints/test_server_oot_registration.py
- label: Examples Test
working_dir: "/vllm-workspace/examples"
@@ -90,7 +93,7 @@ steps:
- bash run-benchmarks.sh
- label: Documentation Build
- working_dir: "/vllm-workspace/docs"
+ working_dir: "/vllm-workspace/test_docs/docs"
no_gpu: True
commands:
- pip install -r requirements-docs.txt
diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2
index 4dde733581822..3ed23c62c005d 100644
--- a/.buildkite/test-template.j2
+++ b/.buildkite/test-template.j2
@@ -8,6 +8,9 @@ steps:
queue: amd
command: bash .buildkite/run-amd-test.sh
+ - label: "CPU Test"
+ command: bash .buildkite/run-cpu-test.sh
+
- label: ":docker: build image"
commands:
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 5211dc180798e..fc97e33c19af2 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -49,7 +49,7 @@ jobs:
matrix:
os: ['ubuntu-20.04']
python-version: ['3.8', '3.9', '3.10', '3.11']
- pytorch-version: ['2.1.2'] # Must be the most recent version that meets requirements.txt.
+ pytorch-version: ['2.2.1'] # Must be the most recent version that meets requirements-cuda.txt.
cuda-version: ['11.8', '12.1']
steps:
diff --git a/.github/workflows/scripts/build.sh b/.github/workflows/scripts/build.sh
index 2578d448436d2..60a3978f9abd7 100644
--- a/.github/workflows/scripts/build.sh
+++ b/.github/workflows/scripts/build.sh
@@ -9,12 +9,13 @@ LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
# Install requirements
$python_executable -m pip install wheel packaging
-$python_executable -m pip install -r requirements.txt
+$python_executable -m pip install -r requirements-cuda.txt
# Limit the number of parallel jobs to avoid OOM
export MAX_JOBS=1
# Make sure punica is built for the release (for LoRA)
export VLLM_INSTALL_PUNICA_KERNELS=1
-
+# Make sure release wheels are built for the following architectures
+export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
# Build
$python_executable setup.py bdist_wheel --dist-dir=dist
diff --git a/.gitignore b/.gitignore
index b5195629e5cf3..b1513ef0ddb0c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -181,6 +181,7 @@ _build/
# hip files generated by PyTorch
*.hip
*_hip*
+hip_compat.h
# Benchmark dataset
*.json
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 412b9c0cd59e0..1845151181284 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -2,7 +2,10 @@ cmake_minimum_required(VERSION 3.21)
project(vllm_extensions LANGUAGES CXX)
+set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")
+
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
+message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
@@ -16,7 +19,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11")
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
# Supported AMD GPU architectures.
-set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")
+set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100")
#
# Supported/expected torch versions for CUDA/ROCm.
@@ -28,7 +31,7 @@ set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")
# requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from Dockerfile.rocm
#
-set(TORCH_SUPPORTED_VERSION_CUDA "2.1.2")
+set(TORCH_SUPPORTED_VERSION_CUDA "2.2.1")
set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
@@ -76,6 +79,19 @@ find_package(Torch REQUIRED)
find_library(torch_python_LIBRARY torch_python PATHS
"${TORCH_INSTALL_PREFIX}/lib")
+#
+# Forward the non-CUDA device extensions to external CMake scripts.
+#
+if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
+ NOT VLLM_TARGET_DEVICE STREQUAL "rocm")
+ if (VLLM_TARGET_DEVICE STREQUAL "cpu")
+ include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
+ else()
+ message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}")
+ endif()
+ return()
+endif()
+
#
# Set up GPU language and check the torch version and warn if it isn't
# what is expected.
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 8db5e569b6aec..81a8db2b268b0 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -21,7 +21,6 @@ Express your support on Twitter if vLLM aids you, or simply offer your appreciat
### Build from source
```bash
-pip install -r requirements.txt
pip install -e . # This may take several minutes.
```
@@ -30,6 +29,8 @@ pip install -e . # This may take several minutes.
```bash
pip install -r requirements-dev.txt
+# linting and formatting
+bash format.sh
# Static type checking
mypy
# Unit tests
diff --git a/Dockerfile b/Dockerfile
index f975530e09407..d1d29177b0f44 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,6 +2,7 @@
# to run the OpenAI compatible server.
#################### BASE BUILD IMAGE ####################
+# prepare basic build environment
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
RUN apt-get update -y \
@@ -16,18 +17,26 @@ RUN ldconfig /usr/local/cuda-12.1/compat/
WORKDIR /workspace
# install build and runtime dependencies
-COPY requirements.txt requirements.txt
+COPY requirements-common.txt requirements-common.txt
+COPY requirements-cuda.txt requirements-cuda.txt
RUN --mount=type=cache,target=/root/.cache/pip \
- pip install -r requirements.txt
+ pip install -r requirements-cuda.txt
# install development dependencies
COPY requirements-dev.txt requirements-dev.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements-dev.txt
+
+# cuda arch list used by torch
+# can be useful for both `dev` and `test`
+# explicitly set the list to avoid issues with torch 2.2
+# see https://github.com/pytorch/pytorch/pull/123243
+ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
+ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
#################### BASE BUILD IMAGE ####################
-#################### EXTENSION BUILD IMAGE ####################
+#################### WHEEL BUILD IMAGE ####################
FROM dev AS build
# install build dependencies
@@ -38,18 +47,16 @@ RUN --mount=type=cache,target=/root/.cache/pip \
# install compiler cache to speed up compilation leveraging local or remote caching
RUN apt-get update -y && apt-get install -y ccache
-# copy input files
+# files and directories related to build wheels
COPY csrc csrc
COPY setup.py setup.py
COPY cmake cmake
COPY CMakeLists.txt CMakeLists.txt
-COPY requirements.txt requirements.txt
+COPY requirements-common.txt requirements-common.txt
+COPY requirements-cuda.txt requirements-cuda.txt
COPY pyproject.toml pyproject.toml
-COPY vllm/__init__.py vllm/__init__.py
+COPY vllm vllm
-# cuda arch list used by torch
-ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
-ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# max jobs used by Ninja to build extensions
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
@@ -61,7 +68,15 @@ ENV VLLM_INSTALL_PUNICA_KERNELS=1
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
- python3 setup.py build_ext --inplace
+ --mount=type=cache,target=/root/.cache/pip \
+ python3 setup.py bdist_wheel --dist-dir=dist
+
+# the `vllm_nccl` package must be installed from source distribution
+# pip is too smart to store a wheel in the cache, and other CI jobs
+# will directly use the wheel from the cache, which is not what we want.
+# we need to remove it manually
+RUN --mount=type=cache,target=/root/.cache/pip \
+ pip cache remove vllm_nccl*
#################### EXTENSION Build IMAGE ####################
#################### FLASH_ATTENTION Build IMAGE ####################
@@ -81,57 +96,59 @@ RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
#################### FLASH_ATTENTION Build IMAGE ####################
+#################### vLLM installation IMAGE ####################
+# image with vLLM installed
+FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base
+WORKDIR /vllm-workspace
+
+RUN apt-get update -y \
+ && apt-get install -y python3-pip git vim
+
+# Workaround for https://github.com/openai/triton/issues/2507 and
+# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
+# this won't be needed for future versions of this docker image
+# or future versions of triton.
+RUN ldconfig /usr/local/cuda-12.1/compat/
+
+# install vllm wheel first, so that torch etc will be installed
+RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
+ --mount=type=cache,target=/root/.cache/pip \
+ pip install dist/*.whl --verbose
+
+RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
+ --mount=type=cache,target=/root/.cache/pip \
+ pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
+#################### vLLM installation IMAGE ####################
+
+
#################### TEST IMAGE ####################
# image to run unit testing suite
-FROM dev AS test
+# note that this uses vllm installed by `pip`
+FROM vllm-base AS test
-# copy pytorch extensions separately to avoid having to rebuild
-# when python code changes
-WORKDIR /vllm-workspace
-# ADD is used to preserve directory structure
ADD . /vllm-workspace/
-COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/
-# Install flash attention (from pre-built wheel)
-RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
- pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
-# ignore build dependencies installation because we are using pre-complied extensions
-RUN rm pyproject.toml
-RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . --verbose
-#################### TEST IMAGE ####################
-
-#################### RUNTIME BASE IMAGE ####################
-# We used base cuda image because pytorch installs its own cuda libraries.
-# However pynccl depends on cuda libraries so we had to switch to the runtime image
-# In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
-FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base
-
-# libnccl required for ray
-RUN apt-get update -y \
- && apt-get install -y python3-pip
-
-WORKDIR /workspace
-COPY requirements.txt requirements.txt
+# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/pip \
- pip install -r requirements.txt
+ pip install -r requirements-dev.txt
-# Install flash attention (from pre-built wheel)
-RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
- pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
-
-#################### RUNTIME BASE IMAGE ####################
+# doc requires source code
+# we hide them inside `test_docs/`, so that this source code
+# will not be imported by other tests
+RUN mkdir test_docs
+RUN mv docs test_docs/
+RUN mv vllm test_docs/
+#################### TEST IMAGE ####################
#################### OPENAI API SERVER ####################
# openai api server alternative
FROM vllm-base AS vllm-openai
+
# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
pip install accelerate hf_transfer modelscope
-COPY --from=build /workspace/vllm/*.so /workspace/vllm/
-COPY vllm vllm
-
ENV VLLM_USAGE_SOURCE production-docker-image
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
diff --git a/Dockerfile.cpu b/Dockerfile.cpu
new file mode 100644
index 0000000000000..4251fddd6cc3b
--- /dev/null
+++ b/Dockerfile.cpu
@@ -0,0 +1,20 @@
+# This vLLM Dockerfile is used to construct an image that can build and run vLLM on the x86 CPU platform.
+
+FROM ubuntu:22.04
+
+RUN apt-get update -y \
+ && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
+ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
+
+RUN pip install --upgrade pip \
+ && pip install wheel packaging ninja setuptools>=49.4.0 numpy
+
+COPY ./ /workspace/vllm
+
+WORKDIR /workspace/vllm
+
+RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
+
+RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
+
+CMD ["/bin/bash"]
diff --git a/Dockerfile.rocm b/Dockerfile.rocm
index 65a367994f960..10b8bf1e7fabd 100644
--- a/Dockerfile.rocm
+++ b/Dockerfile.rocm
@@ -23,6 +23,9 @@ RUN echo "FA_BRANCH is $FA_BRANCH"
# In that case, we need to use the python reference attention implementation in vllm
ARG BUILD_FA="1"
+# whether to build triton on rocm
+ARG BUILD_TRITON="1"
+
# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y
@@ -75,6 +78,17 @@ RUN if [ "$BUILD_FA" = "1" ]; then \
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
+# build triton
+RUN if [ "$BUILD_TRITON" = "1" ]; then \
+ mkdir -p libs \
+ && cd libs \
+ && pip uninstall -y triton \
+ && git clone https://github.com/ROCm/triton.git \
+ && cd triton/python \
+ && pip3 install . \
+ && cd ../..; \
+ fi
+
COPY ./ /app/vllm
RUN python3 -m pip install --upgrade pip
diff --git a/MANIFEST.in b/MANIFEST.in
index aa16da6500e6c..d385f194c6c0f 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,5 +1,6 @@
include LICENSE
-include requirements.txt
+include requirements-common.txt
+include requirements-cuda.txt
include CMakeLists.txt
recursive-include cmake *
diff --git a/README.md b/README.md
index 08e46b68cb7ce..d53227b82d87a 100644
--- a/README.md
+++ b/README.md
@@ -14,18 +14,8 @@ Easy, fast, and cheap LLM serving for everyone
----
-
-**The Third vLLM Bay Area Meetup (April 2nd 6pm-8:30pm PT)**
-
-We are thrilled to announce our third vLLM Meetup!
-The vLLM team will share recent updates and roadmap.
-We will also have vLLM collaborators from Roblox coming up to the stage to discuss their experience in deploying LLMs with vLLM.
-Please register [here](https://robloxandvllmmeetup2024.splashthat.com/) and join us!
-
----
-
*Latest News* 🔥
+- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
- [2024/01] Added ROCm 6.0 support to vLLM.
- [2023/12] Added ROCm 5.7 support to vLLM.
@@ -80,6 +70,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
+- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.)
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
@@ -88,7 +79,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
-- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
+- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.)
- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py
index 96a372e5511b7..ad428bd1c3644 100644
--- a/benchmarks/backend_request_func.py
+++ b/benchmarks/backend_request_func.py
@@ -334,7 +334,8 @@ async def async_request_openai_chat_completions(
timestamp = time.perf_counter()
data = json.loads(chunk)
- if "content" in data["choices"][0]["delta"]:
+ delta = data["choices"][0]["delta"]
+ if delta.get("content", None):
# First token
if ttft == 0:
ttft = time.perf_counter() - st
@@ -345,8 +346,7 @@ async def async_request_openai_chat_completions(
output.itl.append(timestamp -
most_recent_timestamp)
- generated_text += data["choices"][0]["delta"][
- "content"]
+ generated_text += delta["content"]
most_recent_timestamp = timestamp
diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index da02493b17fd3..91510dafc57a5 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -24,6 +24,7 @@ def main(args: argparse.Namespace):
dtype=args.dtype,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
+ quantization_param_path=args.quantization_param_path,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
enable_chunked_prefill=args.enable_chunked_prefill,
@@ -67,7 +68,8 @@ def main(args: argparse.Namespace):
return latency
print("Warming up...")
- run_to_completion(profile_dir=None)
+ for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
+ run_to_completion(profile_dir=None)
if args.profile:
profile_dir = args.profile_result_dir
@@ -83,7 +85,12 @@ def main(args: argparse.Namespace):
latencies = []
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
latencies.append(run_to_completion(profile_dir=None))
+ latencies = np.array(latencies)
+ percentages = [10, 25, 50, 75, 90]
+ percentiles = np.percentile(latencies, percentages)
print(f'Avg latency: {np.mean(latencies)} seconds')
+ for percentage, percentile in zip(percentages, percentiles):
+ print(f'{percentage}% percentile latency: {percentile} seconds')
if __name__ == '__main__':
@@ -105,9 +112,13 @@ if __name__ == '__main__':
default=1,
help='Number of generated sequences per prompt.')
parser.add_argument('--use-beam-search', action='store_true')
+ parser.add_argument('--num-iters-warmup',
+ type=int,
+ default=10,
+ help='Number of iterations to run for warmup.')
parser.add_argument('--num-iters',
type=int,
- default=3,
+ default=30,
help='Number of iterations to run.')
parser.add_argument('--trust-remote-code',
action='store_true',
@@ -127,10 +138,23 @@ if __name__ == '__main__':
parser.add_argument(
"--kv-cache-dtype",
type=str,
- choices=['auto', 'fp8_e5m2'],
+ choices=['auto', 'fp8'],
default='auto',
help=
- 'Data type for kv cache storage. If "auto", will use model data type.')
+ 'Data type for kv cache storage. If "auto", will use model data type. '
+ 'FP8_E5M2 (without scaling) is only supported on cuda version greater '
+ 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
+ 'common inference criteria.')
+ parser.add_argument(
+ '--quantization-param-path',
+ type=str,
+ default=None,
+ help='Path to the JSON file containing the KV cache scaling factors. '
+ 'This should generally be supplied, when KV cache dtype is FP8. '
+ 'Otherwise, KV cache scaling factors default to 1.0, which may cause '
+ 'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
+ 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
+ 'instead supported for common inference criteria.')
parser.add_argument(
'--profile',
action='store_true',
@@ -145,8 +169,8 @@ if __name__ == '__main__':
"--device",
type=str,
default="cuda",
- choices=["cuda"],
- help='device type for vLLM execution, supporting CUDA only currently.')
+ choices=["cuda", "cpu"],
+ help='device type for vLLM execution, supporting CUDA and CPU.')
parser.add_argument('--block-size',
type=int,
default=16,
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index bc7812ed4119e..6054df439fa57 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -110,7 +110,9 @@ def sample_sonnet_requests(
prefix_len: int,
tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, str, int, int]]:
- assert input_len > prefix_len, "input_len must be greater than prefix_len."
+ assert (
+ input_len > prefix_len
+ ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."
# Load the dataset.
with open(dataset_path) as f:
@@ -131,8 +133,9 @@ def sample_sonnet_requests(
base_message, add_generation_prompt=True, tokenize=False)
base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)
- assert (input_len > base_prompt_offset
- ), f"Please set 'args.input-len' higher than {base_prompt_offset}."
+ assert (
+ input_len > base_prompt_offset
+ ), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}."
num_input_lines = round(
(input_len - base_prompt_offset) / average_poem_len)
@@ -140,7 +143,7 @@ def sample_sonnet_requests(
# prompt are fixed poem lines.
assert (
prefix_len > base_prompt_offset
- ), f"Please set 'args.prefix-len' higher than {base_prompt_offset}."
+ ), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}."
num_prefix_lines = round(
(prefix_len - base_prompt_offset) / average_poem_len)
@@ -373,9 +376,9 @@ def main(args: argparse.Namespace):
input_requests = sample_sonnet_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
- input_len=args.input_len,
- output_len=args.output_len,
- prefix_len=args.prefix_len,
+ input_len=args.sonnet_input_len,
+ output_len=args.sonnet_output_len,
+ prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer,
)
input_requests = [(prompt, prompt_len, output_len)
@@ -388,9 +391,9 @@ def main(args: argparse.Namespace):
input_requests = sample_sonnet_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
- input_len=args.input_len,
- output_len=args.output_len,
- prefix_len=args.prefix_len,
+ input_len=args.sonnet_input_len,
+ output_len=args.sonnet_output_len,
+ prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer,
)
input_requests = [(prompt_formatted, prompt_len, output_len)
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 9d84bde17d6d0..e71338273d1e5 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -29,22 +29,23 @@ def sample_requests(
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]
- # Tokenize the prompts and completions.
- prompts = [prompt for prompt, _ in dataset]
- prompt_token_ids = tokenizer(prompts).input_ids
- completions = [completion for _, completion in dataset]
- completion_token_ids = tokenizer(completions).input_ids
- tokenized_dataset = []
- for i in range(len(dataset)):
- output_len = len(completion_token_ids[i])
- if fixed_output_len is not None:
- output_len = fixed_output_len
- tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
+ # Shuffle the dataset.
+ random.shuffle(dataset)
- # Filter out too long sequences.
+ # Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
- for prompt, prompt_token_ids, output_len in tokenized_dataset:
+ for i in range(len(dataset)):
+ if len(filtered_dataset) == num_requests:
+ break
+
+ # Tokenize the prompts and completions.
+ prompt = dataset[i][0]
+ prompt_token_ids = tokenizer(prompt).input_ids
+ completion = dataset[i][1]
+ completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
+ output_len = len(completion_token_ids
+ ) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
@@ -53,9 +54,7 @@ def sample_requests(
continue
filtered_dataset.append((prompt, prompt_len, output_len))
- # Sample the requests.
- sampled_requests = random.sample(filtered_dataset, num_requests)
- return sampled_requests
+ return filtered_dataset
def run_vllm(
@@ -72,6 +71,7 @@ def run_vllm(
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
+ quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
gpu_memory_utilization: float = 0.9,
@@ -89,6 +89,7 @@ def run_vllm(
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
+ quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir)
@@ -217,7 +218,8 @@ def main(args: argparse.Namespace):
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager,
- args.kv_cache_dtype, args.device,
+ args.kv_cache_dtype,
+ args.quantization_param_path, args.device,
args.enable_prefix_caching,
args.gpu_memory_utilization, args.download_dir)
elif args.backend == "hf":
@@ -306,16 +308,29 @@ if __name__ == "__main__":
parser.add_argument(
"--kv-cache-dtype",
type=str,
- choices=["auto", "fp8_e5m2"],
+ choices=["auto", "fp8"],
default="auto",
help=
- 'Data type for kv cache storage. If "auto", will use model data type.')
+ 'Data type for kv cache storage. If "auto", will use model data type. '
+ 'FP8_E5M2 (without scaling) is only supported on cuda version greater '
+ 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
+ 'common inference criteria.')
+ parser.add_argument(
+ '--quantization-param-path',
+ type=str,
+ default=None,
+ help='Path to the JSON file containing the KV cache scaling factors. '
+ 'This should generally be supplied, when KV cache dtype is FP8. '
+ 'Otherwise, KV cache scaling factors default to 1.0, which may cause '
+ 'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
+ 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
+ 'instead supported for common inference criteria.')
parser.add_argument(
"--device",
type=str,
default="cuda",
- choices=["cuda"],
- help='device type for vLLM execution, supporting CUDA only currently.')
+ choices=["cuda", "cpu"],
+ help='device type for vLLM execution, supporting CUDA and CPU.')
parser.add_argument(
"--enable-prefix-caching",
action='store_true',
diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py
index f6c8f900a3bff..f71d1fcaaef50 100644
--- a/benchmarks/kernels/benchmark_paged_attention.py
+++ b/benchmarks/kernels/benchmark_paged_attention.py
@@ -97,6 +97,9 @@ def main(
torch.cuda.cudart().cudaProfilerStart()
start_time = time.perf_counter()
+ # Using default kv_scale
+ kv_scale = 1.0
+
for _ in range(num_iters):
if version == "v1":
ops.paged_attention_v1(
@@ -112,6 +115,7 @@ def main(
max_context_len,
alibi_slopes,
kv_cache_dtype,
+ kv_scale,
)
elif version == "v2":
ops.paged_attention_v2(
@@ -130,6 +134,7 @@ def main(
max_context_len,
alibi_slopes,
kv_cache_dtype,
+ kv_scale,
)
else:
raise ValueError(f"Invalid version: {version}")
@@ -179,11 +184,13 @@ if __name__ == '__main__':
parser.add_argument(
"--kv-cache-dtype",
type=str,
- choices=["auto", "fp8_e5m2"],
+ choices=["auto", "fp8"],
default="auto",
help=
- 'Data type for kv cache storage. If "auto", will use model data type.')
- parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
+ 'Data type for kv cache storage. If "auto", will use model data type. '
+ 'FP8_E5M2 (without scaling) is only supported on cuda version greater '
+ 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
+ 'common inference criteria.')
args = parser.parse_args()
print(args)
diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
new file mode 100644
index 0000000000000..0cf37769a6960
--- /dev/null
+++ b/cmake/cpu_extension.cmake
@@ -0,0 +1,90 @@
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+#
+# Define environment variables for special configurations
+#
+if(DEFINED ENV{VLLM_CPU_AVX512BF16})
+ set(ENABLE_AVX512BF16 ON)
+endif()
+
+include_directories("${CMAKE_SOURCE_DIR}/csrc")
+
+#
+# Check the compile flags
+#
+list(APPEND CXX_COMPILE_FLAGS
+ "-fopenmp"
+ "-DVLLM_CPU_EXTENSION")
+
+execute_process(COMMAND cat /proc/cpuinfo
+ RESULT_VARIABLE CPUINFO_RET
+ OUTPUT_VARIABLE CPUINFO)
+
+if (NOT CPUINFO_RET EQUAL 0)
+ message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo")
+endif()
+
+#
+# Check whether an ISA flag string (TARGET) appears in the CPUINFO text.
+# Sets OUT to ON/OFF in the caller's scope. Expansions are quoted so that
+# semicolons or spaces in the input cannot split it into multiple arguments.
+#
+function (find_isa CPUINFO TARGET OUT)
+    string(FIND "${CPUINFO}" "${TARGET}" ISA_FOUND)
+    if(NOT ISA_FOUND EQUAL -1)
+        set(${OUT} ON PARENT_SCOPE)
+    else()
+        set(${OUT} OFF PARENT_SCOPE)
+    endif()
+endfunction()
+
+find_isa("${CPUINFO}" "avx512f" AVX512_FOUND)
+
+if (AVX512_FOUND)
+ list(APPEND CXX_COMPILE_FLAGS
+ "-mavx512f"
+ "-mavx512vl"
+ "-mavx512bw"
+ "-mavx512dq")
+
+    find_isa("${CPUINFO}" "avx512_bf16" AVX512BF16_FOUND)
+ if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
+ if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
+ CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
+ list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
+ else()
+ message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
+ endif()
+ else()
+ message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
+ endif()
+else()
+ message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.")
+endif()
+
+message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
+
+
+#
+# Define extension targets
+#
+
+#
+# _C extension
+#
+set(VLLM_EXT_SRC
+ "csrc/cpu/activation.cpp"
+ "csrc/cpu/attention.cpp"
+ "csrc/cpu/cache.cpp"
+ "csrc/cpu/layernorm.cpp"
+ "csrc/cpu/pos_encoding.cpp"
+ "csrc/cpu/pybind.cpp")
+
+define_gpu_extension_target(
+ _C
+ DESTINATION vllm
+ LANGUAGE CXX
+ SOURCES ${VLLM_EXT_SRC}
+ COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
+ WITH_SOABI
+)
+
+add_custom_target(default)
+message(STATUS "Enabling C extension.")
+add_dependencies(default _C)
+
diff --git a/cmake/utils.cmake b/cmake/utils.cmake
index c7d3d85389838..7c71673e36f29 100644
--- a/cmake/utils.cmake
+++ b/cmake/utils.cmake
@@ -100,6 +100,8 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
+ endif()
+ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
list(REMOVE_ITEM GPU_FLAGS
"-D__CUDA_NO_HALF_OPERATORS__"
"-D__CUDA_NO_HALF_CONVERSIONS__"
@@ -117,6 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list(APPEND GPU_FLAGS
"-DUSE_ROCM"
+ "-DENABLE_FP8_E4M3"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc")
diff --git a/csrc/attention/attention_dtypes.h b/csrc/attention/attention_dtypes.h
index 61748e6b1eee6..64f86381d9db9 100644
--- a/csrc/attention/attention_dtypes.h
+++ b/csrc/attention/attention_dtypes.h
@@ -4,4 +4,4 @@
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"
-#include "dtype_fp8_e5m2.cuh"
+#include "dtype_fp8.cuh"
diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index 5e61668d5cc1a..f3a5bbfd3098d 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -22,12 +22,26 @@
#include "attention_dtypes.h"
#include "attention_utils.cuh"
-#ifdef ENABLE_FP8_E5M2
+
+#if defined(ENABLE_FP8_E5M2)
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
+#elif defined(ENABLE_FP8_E4M3)
+#include "../quantization/fp8/amd_detail/quant_utils.cuh"
#endif
#include <algorithm>
+#ifdef USE_ROCM
+ #include <hip/hip_bf16.h>
+ typedef __hip_bfloat16 __nv_bfloat16;
+#endif
+
+#ifndef USE_ROCM
+#define WARP_SIZE 32
+#else
+#define WARP_SIZE warpSize
+#endif
+
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
@@ -78,7 +92,7 @@ template<
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
- bool IS_FP8_E5M2_KV_CACHE,
+ bool IS_FP8_KV_CACHE,
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
@@ -95,7 +109,8 @@ __device__ void paged_attention_kernel(
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
- const int kv_head_stride) {
+ const int kv_head_stride,
+ const float kv_scale) {
const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z;
@@ -142,7 +157,7 @@ __device__ void paged_attention_kernel(
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
-#ifdef ENABLE_FP8_E5M2
+#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
#endif
@@ -208,11 +223,16 @@ __device__ void paged_attention_kernel(
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
- if constexpr (IS_FP8_E5M2_KV_CACHE) {
-#ifdef ENABLE_FP8_E5M2
+ if constexpr (IS_FP8_KV_CACHE) {
+#if defined(ENABLE_FP8_E5M2)
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
// Vector conversion from Quant_vec to K_vec.
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
+#elif defined(ENABLE_FP8_E4M3)
+ Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+ // Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k
+ // cache vec to k vec in higher precision (FP16, BFloat16, etc.)
+ k_vecs[j] = fp8_e4m3::scaled_vec_conversion<K_vec, Quant_vec>(k_vec_quant, kv_scale);
#else
assert(false);
#endif
@@ -292,7 +312,7 @@ __device__ void paged_attention_kernel(
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
-#ifdef ENABLE_FP8_E5M2
+#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
#endif
using Float_L_vec = typename FloatVec<L_vec>::Type;
@@ -328,11 +348,16 @@ __device__ void paged_attention_kernel(
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec;
- if constexpr (IS_FP8_E5M2_KV_CACHE) {
-#ifdef ENABLE_FP8_E5M2
+ if constexpr (IS_FP8_KV_CACHE) {
+#if defined(ENABLE_FP8_E5M2)
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
+#elif defined(ENABLE_FP8_E4M3)
+ V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+ // Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert
+ // FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.)
+ v_vec = fp8_e4m3::scaled_vec_conversion<V_vec, V_quant_vec>(v_quant_vec, kv_scale);
#else
assert(false);
#endif
@@ -423,7 +448,7 @@ template<
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
- bool IS_FP8_E5M2_KV_CACHE>
+ bool IS_FP8_KV_CACHE>
__global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
@@ -437,11 +462,12 @@ __global__ void paged_attention_v1_kernel(
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
- const int kv_head_stride) {
- paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
+ const int kv_head_stride,
+ const float kv_scale) {
+ paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
- max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
+ max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
}
// Grid: (num_heads, num_seqs, max_num_partitions).
@@ -451,7 +477,7 @@ template<
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
- bool IS_FP8_E5M2_KV_CACHE,
+ bool IS_FP8_KV_CACHE,
int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
@@ -468,11 +494,12 @@ __global__ void paged_attention_v2_kernel(
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
- const int kv_head_stride) {
- paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
+ const int kv_head_stride,
+ const float kv_scale) {
+ paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
- q_stride, kv_block_stride, kv_head_stride);
+ q_stride, kv_block_stride, kv_head_stride, kv_scale);
}
// Grid: (num_heads, num_seqs).
@@ -579,9 +606,9 @@ __global__ void paged_attention_v2_reduce_kernel(
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
+ IS_FP8_KV_CACHE>), shared_mem_size); \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
+ IS_FP8_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
out_ptr, \
query_ptr, \
key_cache_ptr, \
@@ -594,14 +621,15 @@ __global__ void paged_attention_v2_reduce_kernel(
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
- kv_head_stride);
+ kv_head_stride, \
+ kv_scale);
// TODO(woosuk): Tune NUM_THREADS.
template<
typename T,
typename CACHE_T,
int BLOCK_SIZE,
- bool IS_FP8_E5M2_KV_CACHE,
+ bool IS_FP8_KV_CACHE,
int NUM_THREADS = 128>
void paged_attention_v1_launcher(
torch::Tensor& out,
@@ -613,7 +641,8 @@ void paged_attention_v1_launcher(
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len,
- const c10::optional<torch::Tensor>& alibi_slopes) {
+ const c10::optional<torch::Tensor>& alibi_slopes,
+ float kv_scale) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -677,8 +706,8 @@ void paged_attention_v1_launcher(
}
}
-#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
- paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
+#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
+ paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
out, \
query, \
key_cache, \
@@ -688,20 +717,21 @@ void paged_attention_v1_launcher(
block_tables, \
context_lens, \
max_context_len, \
- alibi_slopes);
+ alibi_slopes, \
+ kv_scale);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
switch (block_size) { \
case 8: \
- CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
+ CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
break; \
case 16: \
- CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
+ CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
break; \
case 32: \
- CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
+ CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
@@ -720,7 +750,8 @@ void paged_attention_v1(
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
- const std::string& kv_cache_dtype) {
+ const std::string& kv_cache_dtype,
+ float kv_scale) {
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
@@ -731,7 +762,7 @@ void paged_attention_v1(
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
- } else if (kv_cache_dtype == "fp8_e5m2") {
+ } else if (kv_cache_dtype == "fp8") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
} else if (query.dtype() == at::ScalarType::Half) {
@@ -748,7 +779,7 @@ void paged_attention_v1(
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
+ IS_FP8_KV_CACHE, PARTITION_SIZE> \
<<<grid, block, shared_mem_size, stream>>>( \
exp_sums_ptr, \
max_logits_ptr, \
@@ -764,7 +795,8 @@ void paged_attention_v1(
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
- kv_head_stride); \
+ kv_head_stride, \
+ kv_scale); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
out_ptr, \
@@ -778,7 +810,7 @@ template<
typename T,
typename CACHE_T,
int BLOCK_SIZE,
- bool IS_FP8_E5M2_KV_CACHE,
+ bool IS_FP8_KV_CACHE,
int NUM_THREADS = 128,
int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
@@ -794,7 +826,8 @@ void paged_attention_v2_launcher(
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len,
- const c10::optional<torch::Tensor>& alibi_slopes) {
+ const c10::optional<torch::Tensor>& alibi_slopes,
+ float kv_scale) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -864,8 +897,8 @@ void paged_attention_v2_launcher(
}
}
-#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
- paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
+ paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
out, \
exp_sums, \
max_logits, \
@@ -878,20 +911,21 @@ void paged_attention_v2_launcher(
block_tables, \
context_lens, \
max_context_len, \
- alibi_slopes);
+ alibi_slopes, \
+ kv_scale);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
switch (block_size) { \
case 8: \
- CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
+ CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
break; \
case 16: \
- CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
+ CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
break; \
case 32: \
- CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
+ CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
@@ -913,7 +947,8 @@ void paged_attention_v2(
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
- const std::string& kv_cache_dtype) {
+ const std::string& kv_cache_dtype,
+ float kv_scale) {
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
@@ -924,7 +959,7 @@ void paged_attention_v2(
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
- } else if (kv_cache_dtype == "fp8_e5m2") {
+ } else if (kv_cache_dtype == "fp8") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
} else if (query.dtype() == at::ScalarType::Half) {
diff --git a/csrc/attention/dtype_fp8_e5m2.cuh b/csrc/attention/dtype_fp8.cuh
similarity index 89%
rename from csrc/attention/dtype_fp8_e5m2.cuh
rename to csrc/attention/dtype_fp8.cuh
index 0580fbb8e863f..d11dee91ebe87 100644
--- a/csrc/attention/dtype_fp8_e5m2.cuh
+++ b/csrc/attention/dtype_fp8.cuh
@@ -8,7 +8,7 @@
#endif
namespace vllm {
-#ifdef ENABLE_FP8_E5M2
+#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
// fp8 vector types for quantization of kv cache
template<>
diff --git a/csrc/cache.h b/csrc/cache.h
index 765e231abd26f..718a5f6cfd7f7 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -21,9 +21,10 @@ void reshape_and_cache(
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
- const std::string& kv_cache_dtype);
+ const std::string& kv_cache_dtype,
+ const float kv_scale);
// Just for unittest
-void convert_fp8_e5m2(
+void convert_fp8(
torch::Tensor& src_cache,
torch::Tensor& dst_cache);
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 7254010b8e3a9..24aaa2ff3e263 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -4,8 +4,10 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"
-#ifdef ENABLE_FP8_E5M2
+#if defined(ENABLE_FP8_E5M2)
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
+#elif defined(ENABLE_FP8_E4M3)
+#include "quantization/fp8/amd_detail/quant_utils.cuh"
#endif
#include <algorithm>
@@ -151,7 +153,7 @@ void copy_blocks(
namespace vllm {
-template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
+template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache>
__global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
@@ -163,7 +165,8 @@ __global__ void reshape_and_cache_kernel(
const int num_heads,
const int head_size,
const int block_size,
- const int x) {
+ const int x,
+ const float kv_scale) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
@@ -195,10 +198,13 @@ __global__ void reshape_and_cache_kernel(
+ block_offset;
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
- if constexpr (is_fp8_e5m2_kv_cache) {
-#ifdef ENABLE_FP8_E5M2
+ if constexpr (is_fp8_kv_cache) {
+#if defined(ENABLE_FP8_E5M2)
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<cache_t, scalar_t>(tgt_key);
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<cache_t, scalar_t>(tgt_value);
+#elif defined(ENABLE_FP8_E4M3)
+ key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<cache_t, scalar_t>(tgt_key, kv_scale);
+ value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<cache_t, scalar_t>(tgt_value, kv_scale);
#else
assert(false);
#endif
@@ -211,8 +217,8 @@ __global__ void reshape_and_cache_kernel(
} // namespace vllm
-#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
- vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
+#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
+ vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
@@ -223,7 +229,8 @@ __global__ void reshape_and_cache_kernel(
num_heads, \
head_size, \
block_size, \
- x);
+ x, \
+ kv_scale);
void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
@@ -231,7 +238,8 @@ void reshape_and_cache(
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens]
- const std::string& kv_cache_dtype)
+ const std::string& kv_cache_dtype,
+ const float kv_scale)
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
@@ -254,7 +262,7 @@ void reshape_and_cache(
} else if (key.dtype() == at::ScalarType::BFloat16) {
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
}
- } else if (kv_cache_dtype == "fp8_e5m2") {
+ } else if (kv_cache_dtype == "fp8") {
if (key.dtype() == at::ScalarType::Float) {
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
} else if (key.dtype() == at::ScalarType::Half) {
@@ -270,15 +278,17 @@ void reshape_and_cache(
namespace vllm {
template<typename Tout, typename Tin>
-__global__ void convert_fp8_e5m2_kernel(
+__global__ void convert_fp8_kernel(
const Tin* __restrict__ src_cache,
Tout* __restrict__ dst_cache,
const int64_t block_stride) {
const int64_t block_idx = blockIdx.x;
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
int64_t idx = block_idx * block_stride + i;
-#ifdef ENABLE_FP8_E5M2
+#if defined(ENABLE_FP8_E5M2)
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
+#elif defined(ENABLE_FP8_E4M3)
+ dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
assert(false);
#endif
@@ -287,16 +297,25 @@ __global__ void convert_fp8_e5m2_kernel(
} // namespace vllm
-#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \
- vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
- reinterpret_cast<Tin*>(src_cache.data_ptr()), \
- reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
+#define CALL_CONVERT_FP8(Tout, Tin) \
+ vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
+ reinterpret_cast<Tin*>(src_cache.data_ptr()), \
+ reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
block_stride);
-void convert_fp8_e5m2(
+void convert_fp8(
torch::Tensor& src_cache,
torch::Tensor& dst_cache)
{
+ torch::Device src_device = src_cache.device();
+ torch::Device dst_device = dst_cache.device();
+ TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
+ TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
+ TORCH_CHECK(
+ src_device.index() == dst_device.index(),
+ "src and dst must be on the same GPU");
+ at::cuda::OptionalCUDAGuard device_guard(src_device);
+
int64_t num_blocks = src_cache.size(0);
int64_t block_stride = src_cache.stride(0);
@@ -305,16 +324,16 @@ void convert_fp8_e5m2(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (src_cache.dtype() == at::ScalarType::Float) {
- CALL_CONVERT_FP8_E5M2(uint8_t, float);
+ CALL_CONVERT_FP8(uint8_t, float);
} else if (src_cache.dtype() == at::ScalarType::Half) {
- CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
+ CALL_CONVERT_FP8(uint8_t, uint16_t);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
- CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
+ CALL_CONVERT_FP8(uint8_t, __nv_bfloat16);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
- CALL_CONVERT_FP8_E5M2(float, uint8_t);
+ CALL_CONVERT_FP8(float, uint8_t);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
- CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
+ CALL_CONVERT_FP8(uint16_t, uint8_t);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
- CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
+ CALL_CONVERT_FP8(__nv_bfloat16, uint8_t);
}
}
diff --git a/csrc/cpu/activation.cpp b/csrc/cpu/activation.cpp
new file mode 100644
index 0000000000000..1bd24eb79d129
--- /dev/null
+++ b/csrc/cpu/activation.cpp
@@ -0,0 +1,148 @@
+#include "cpu_types.hpp"
+
+namespace {
+template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8 &), bool is_gated>
+void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
+ scalar_t *__restrict__ output) {
+ using scalar_vec_t = vec_op::vec_t<scalar_t>;
+ constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
+
+ TORCH_CHECK(d % VEC_ELEM_NUM == 0);
+
+#pragma omp parallel for
+ for (int i = 0; i < num_tokens; ++i) {
+ for (int j = 0; j < d; j += VEC_ELEM_NUM) {
+ int start = i * d;
+ if constexpr (is_gated) {
+ start *= 2;
+ }
+
+ const scalar_vec_t x(input + start + j);
+ const vec_op::FP32Vec8 f32_x(x);
+ vec_op::FP32Vec8 f32_ans = func(f32_x);
+
+ if constexpr (is_gated) {
+ const scalar_vec_t y(input + start + d + j);
+ const vec_op::FP32Vec8 f32_y(y);
+ f32_ans = f32_y * f32_ans;
+ }
+
+ const scalar_vec_t result(f32_ans);
+ result.save(output + i * d + j);
+ }
+ }
+}
+
+FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) {
+ const vec_op::FP32Vec8 zeros(0.0);
+ const vec_op::FP32Vec8 ones(1.0);
+ return x / (ones + (zeros - x).exp());
+}
+
+FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
+ const vec_op::FP32Vec8 ones(1.0);
+ const vec_op::FP32Vec8 w1(0.79788456f);
+ const vec_op::FP32Vec8 w2(0.044715f);
+ const vec_op::FP32Vec8 w3(0.5);
+ const vec_op::FP32Vec8 x3 = x * x * x;
+ const vec_op::FP32Vec8 t = (w1 * (x + w2 * x3)).tanh();
+ return w3 * x * (ones + t);
+}
+
+FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
+ const vec_op::FP32Vec8 ones(1.0);
+ const vec_op::FP32Vec8 w1(0.79788456f);
+ const vec_op::FP32Vec8 w2(0.044715f);
+ const vec_op::FP32Vec8 w3(0.5);
+ const vec_op::FP32Vec8 t = (x * w1 * (ones + x * w2 * x)).tanh();
+ return w3 * x * (ones + t);
+}
+
+FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) {
+ const vec_op::FP32Vec8 ones(1.0);
+ const vec_op::FP32Vec8 w1(M_SQRT1_2);
+ const vec_op::FP32Vec8 w2(0.5);
+ return x * w2 * (ones + (x * w1).er());
+}
+
+FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
+ const vec_op::FP32Vec8 ones(1.0);
+ const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
+ const vec_op::FP32Vec8 w2(0.5);
+ const vec_op::FP32Vec8 w3(0.044715);
+ const vec_op::FP32Vec8 x_3 = x * x * x;
+ const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
+ return x * w2 * (ones + inner.tanh());
+}
+}; // namespace
+
+void silu_and_mul(torch::Tensor &out, torch::Tensor &input) {
+ int num_tokens = input.numel() / input.size(-1);
+ int d = input.size(-1) / 2;
+
+ VLLM_DISPATCH_FLOATING_TYPES(
+ input.scalar_type(), "silu_and_mul_impl", [&] {
+ CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
+ activation_kernel<scalar_t, silu_act, true>(num_tokens, d,
+ input.data_ptr<scalar_t>(),
+ out.data_ptr<scalar_t>());
+ CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
+ });
+}
+
+void gelu_and_mul(torch::Tensor &out, // [..., d]
+ torch::Tensor &input) // [..., 2 * d]
+{
+ int num_tokens = input.numel() / input.size(-1);
+ int d = input.size(-1) / 2;
+
+ VLLM_DISPATCH_FLOATING_TYPES(
+ input.scalar_type(), "gelu_and_mul_impl", [&] {
+ CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
+ activation_kernel<scalar_t, gelu_act, true>(num_tokens, d,
+ input.data_ptr<scalar_t>(),
+ out.data_ptr<scalar_t>());
+ CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
+ });
+}
+
+void gelu_tanh_and_mul(torch::Tensor &out, // [..., d]
+ torch::Tensor &input) // [..., 2 * d]
+{
+ int num_tokens = input.numel() / input.size(-1);
+ int d = input.size(-1) / 2;
+
+ VLLM_DISPATCH_FLOATING_TYPES(
+ input.scalar_type(), "gelu_tanh_and_mul_impl", [&] {
+ CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl)
+ activation_kernel<scalar_t, gelu_tanh_act, true>(
+ num_tokens, d, input.data_ptr<scalar_t>(),
+ out.data_ptr<scalar_t>());
+ CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl)
+ });
+}
+
+void gelu_new(torch::Tensor &out, torch::Tensor &input) {
+ int num_tokens = input.numel() / input.size(-1);
+ int d = input.size(-1);
+
+ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] {
+ CPU_KERNEL_GUARD_IN(gelu_new_impl)
+ activation_kernel<scalar_t, gelu_new_act, false>(
+ num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
+ CPU_KERNEL_GUARD_OUT(gelu_new_impl)
+ });
+}
+
+void gelu_fast(torch::Tensor &out, torch::Tensor &input) {
+ int num_tokens = input.numel() / input.size(-1);
+ int d = input.size(-1);
+
+ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] {
+ CPU_KERNEL_GUARD_IN(gelu_fast_impl)
+ activation_kernel<scalar_t, gelu_fast_act, false>(
+ num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
+ CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
+ });
+}
diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp
new file mode 100644
index 0000000000000..365bbd5e23728
--- /dev/null
+++ b/csrc/cpu/attention.cpp
@@ -0,0 +1,746 @@
+#include "cpu_types.hpp"
+
+namespace {
+
+template <typename scalar_t> struct KernelVecType {
+ using q_load_vec_type = void;
+ using q_vec_type = void;
+ using k_load_vec_type = void;
+ using k_vec_type = void;
+ using qk_acc_vec_type = void;
+ using v_load_vec_type = void;
+};
+
+template <> struct KernelVecType<float> {
+ using q_load_vec_type = vec_op::FP32Vec4;
+ using q_vec_type = vec_op::FP32Vec16;
+ using k_load_vec_type = vec_op::FP32Vec16;
+ using k_vec_type = vec_op::FP32Vec16;
+ using qk_acc_vec_type = vec_op::FP32Vec16;
+ using v_load_vec_type = vec_op::FP32Vec16;
+};
+
+#ifdef __AVX512BF16__
+template <> struct KernelVecType<c10::BFloat16> {
+ using q_load_vec_type = vec_op::BF16Vec8;
+ using q_vec_type = vec_op::BF16Vec32;
+ using k_load_vec_type = vec_op::BF16Vec32;
+ using k_vec_type = vec_op::BF16Vec32;
+ using qk_acc_vec_type = vec_op::FP32Vec16;
+ using v_load_vec_type = vec_op::BF16Vec16;
+};
+#else
+template <> struct KernelVecType<c10::BFloat16> {
+ using q_load_vec_type = vec_op::BF16Vec8;
+ using q_vec_type = vec_op::FP32Vec16;
+ using k_load_vec_type = vec_op::BF16Vec16;
+ using k_vec_type = vec_op::FP32Vec16;
+ using qk_acc_vec_type = vec_op::FP32Vec16;
+ using v_load_vec_type = vec_op::BF16Vec16;
+};
+#endif
+
+template <typename T>
+FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
+ const int capacity) {
+ T max = data[0];
+ for (int i = 1; i < size; ++i) {
+ max = max >= data[i] ? max : data[i];
+ }
+
+ T sum = 0;
+ for (int i = 0; i < size; ++i) {
+ data[i] = std::exp(data[i] - max);
+ sum += data[i];
+ }
+
+ int i = 0;
+ for (; i < size; ++i) {
+ data[i] /= sum;
+ }
+
+ for (; i < capacity; ++i) {
+ data[i] = 0;
+ }
+
+ return {max, sum};
+}
+
+template <typename T>
+FORCE_INLINE std::pair<T, T>
+reduceSoftmaxAlibi(T *data, const int size, const int capacity,
+ const float alibi_slope, const int start_index,
+ const int context_len) {
+ data[0] += alibi_slope * (start_index - context_len + 1);
+ T max = data[0];
+ for (int i = 1; i < size; ++i) {
+ T qk = data[i] + alibi_slope * (start_index + i - context_len + 1);
+ data[i] = qk;
+ max = max >= qk ? max : qk;
+ }
+
+ T sum = 0;
+ for (int i = 0; i < size; ++i) {
+ data[i] = std::exp(data[i] - max);
+ sum += data[i];
+ }
+
+ int i = 0;
+ for (; i < size; ++i) {
+ data[i] /= sum;
+ }
+
+ for (; i < capacity; ++i) {
+ data[i] = 0;
+ }
+
+ return {max, sum};
+}
+
+template <typename T>
+FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data,
+ const int size) {
+ T max = max_data[0];
+ for (int i = 1; i < size; ++i) {
+ max = max >= max_data[i] ? max : max_data[i];
+ }
+
+ T rescaled_sum = 0;
+ for (int i = 0; i < size; ++i) {
+ T rescale_factor = std::exp(max_data[i] - max);
+ rescaled_sum += rescale_factor * sum_data[i];
+ sum_data[i] *= rescale_factor;
+ }
+ for (int i = 0; i < size; ++i) {
+ sum_data[i] /= rescaled_sum + 1e-8;
+ }
+}
+
+template <int x, typename scalar_t, int BLOCK_SIZE, int HEAD_SIZE>
+struct reduceQKBlockKernel {
+ using q_load_vec_type = typename KernelVecType<scalar_t>::q_load_vec_type;
+ using q_vec_type = typename KernelVecType<scalar_t>::q_vec_type;
+ using k_load_vec_type = typename KernelVecType<scalar_t>::k_load_vec_type;
+ using k_vec_type = typename KernelVecType<scalar_t>::k_vec_type;
+ using qk_acc_vec_type = typename KernelVecType<scalar_t>::qk_acc_vec_type;
+
+ constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x;
+ constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP;
+ constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4;
+
+ static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4);
+ static_assert(k_load_vec_type::get_elem_num() % x == 0);
+ static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);
+
+ FORCE_INLINE static void call(const scalar_t *__restrict__ q,
+ const scalar_t *__restrict__ k_block,
+ float *__restrict__ logits, float scale,
+ const int token_num) {
+ const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;
+
+ qk_acc_vec_type group_accums[MAX_GROUP_NUM];
+ if (token_num == BLOCK_SIZE) {
+ for (int q_offset = 0; q_offset < HEAD_SIZE;
+ q_offset += x, k_block += x * BLOCK_SIZE) {
+ q_load_vec_type q_load_group_vec(q + q_offset);
+ q_vec_type q_group_vec(q_load_group_vec);
+
+ vec_op::unroll_loop<int, MAX_GROUP_NUM>(
+ [k_block, &q_group_vec, &group_accums](int token_group_idx) {
+ k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
+ TOKEN_PER_GROUP);
+ k_vec_type k_group_vec(k_load_group_vec);
+ vec_op::fma(group_accums[token_group_idx], q_group_vec,
+ k_group_vec);
+ vec_op::prefetch(k_block + x * BLOCK_SIZE +
+ token_group_idx * x * TOKEN_PER_GROUP);
+ });
+ }
+ } else {
+ for (int q_offset = 0; q_offset < HEAD_SIZE;
+ q_offset += x, k_block += x * BLOCK_SIZE) {
+ q_load_vec_type q_load_group_vec(q + q_offset);
+ q_vec_type q_group_vec(q_load_group_vec);
+ for (int token_group_start = 0; token_group_start < group_num;
+ token_group_start += UNROLL_GROUP_NUM) {
+ vec_op::unroll_loop<int, UNROLL_GROUP_NUM>(
+ [token_group_start, k_block, &q_group_vec,
+ &group_accums](int token_group_idx) {
+ token_group_idx += token_group_start;
+ k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
+ TOKEN_PER_GROUP);
+ k_vec_type k_group_vec(k_load_group_vec);
+ vec_op::fma(group_accums[token_group_idx], q_group_vec,
+ k_group_vec);
+ vec_op::prefetch(k_block + x * BLOCK_SIZE +
+ token_group_idx * x * TOKEN_PER_GROUP);
+ });
+ }
+ }
+ }
+
+ for (int token_group_idx = 0; token_group_idx < group_num;
+ ++token_group_idx) {
+ vec_op::unroll_loop<int, TOKEN_PER_GROUP>(
+ [&group_accums, logits, scale, token_group_idx](int token_idx) {
+ float dot_v =
+ group_accums[token_group_idx]
+ .template reduce_sub_sum<qk_acc_vec_type::get_elem_num() / TOKEN_PER_GROUP>(token_idx);
+ logits[token_group_idx * TOKEN_PER_GROUP + token_idx] =
+ dot_v * scale;
+ });
+ }
+ }
+};
+
+template <int HEAD_PARTITION_SIZE, typename scalar_t, int BLOCK_SIZE, typename acc_t>
+FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
+ acc_t &&acc) {
+ using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
+ constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
+ static_assert(BLOCK_SIZE == ELEM_NUM);
+ vec_op::FP32Vec16 prob_vec(prob);
+
+ vec_op::unroll_loop<int, HEAD_PARTITION_SIZE>([&](int head_elem_idx) {
+ v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx);
+ vec_op::FP32Vec16 fp32_v_vec(v_vec);
+ acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
+ });
+}
+}; // namespace
+
+// Paged attention v1
+namespace {
+template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
+struct paged_attention_v1_impl {
+ static void
+ call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
+ const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size]
+ const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
+ // head_size/x, block_size, x]
+ const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
+ // head_size, block_size]
+ const int num_kv_heads, const float scale,
+ const int
+ *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
+ const int *__restrict__ context_lens, // [num_seqs]
+ const int max_num_blocks_per_seq,
+ const float *__restrict__ alibi_slopes, // [num_heads]
+ const int q_stride, const int kv_block_stride, const int kv_head_stride,
+ const int num_seqs, const int num_heads) {
+ constexpr int x = 16 / sizeof(scalar_t);
+ const int num_queries_per_kv = num_heads / num_kv_heads;
+
+ static_assert(BLOCK_SIZE == 16);
+
+ int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE;
+ int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0;
+ TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0);
+
+ const int parallel_work_item_num = omp_get_max_threads();
+
+ size_t logits_bytes =
+ parallel_work_item_num * max_context_len_padded * sizeof(float);
+ float *logits = (float *)std::aligned_alloc(
+ 64, logits_bytes); // Cacheline alignment for each context token.
+ // [parallel_work_item_num, max_context_len_padded]
+
+#pragma omp parallel for collapse(2) schedule(dynamic, 1)
+ for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
+ for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
+ int context_len = context_lens[seq_idx];
+ const int *seq_block_table =
+ block_tables + max_num_blocks_per_seq * seq_idx;
+ const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
+ const int64_t kv_head_idx = head_idx / num_queries_per_kv;
+ const scalar_t *__restrict__ q_vec_ptr =
+ q + seq_idx * q_stride + head_idx * HEAD_SIZE;
+ const int last_block_token_num =
+ context_len - (block_num - 1) * BLOCK_SIZE;
+ float *__restrict__ thread_block_logits =
+ logits + omp_get_thread_num() * max_context_len_padded;
+
+ // Compute logits
+ for (int block_idx = 0; block_idx < block_num; ++block_idx) {
+ const int64_t physical_block_idx = seq_block_table[block_idx];
+ const scalar_t *__restrict__ k_block_cache_ptr =
+ k_cache + physical_block_idx * kv_block_stride +
+ kv_head_idx * kv_head_stride;
+ float *__restrict__ head_block_logits =
+ thread_block_logits + block_idx * BLOCK_SIZE;
+
+ reduceQKBlockKernel::call(
+ q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
+ block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
+ }
+
+ // Compute softmax
+ if (alibi_slopes) {
+ reduceSoftmaxAlibi(thread_block_logits, context_len,
+ block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
+ context_len);
+ } else {
+ reduceSoftmax(thread_block_logits, context_len,
+ block_num * BLOCK_SIZE);
+ }
+
+ // Compute value
+ constexpr int head_elem_num_per_partition = 16;
+ constexpr int head_partition_num =
+ HEAD_SIZE / head_elem_num_per_partition;
+ for (int head_part_idx = 0; head_part_idx < head_partition_num;
+ ++head_part_idx) {
+ vec_op::FP32Vec16 accums[head_elem_num_per_partition];
+ scalar_t *__restrict__ out_ptr =
+ out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
+ head_part_idx * head_elem_num_per_partition;
+ for (int block_idx = 0; block_idx < block_num; ++block_idx) {
+ const int64_t physical_block_idx = seq_block_table[block_idx];
+ const float *__restrict__ prob_vec_ptr =
+ thread_block_logits + block_idx * BLOCK_SIZE;
+ const scalar_t *__restrict__ v_block_cache_ptr =
+ v_cache + physical_block_idx * kv_block_stride +
+ kv_head_idx * kv_head_stride +
+ BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
+ reduceValueBlock(
+ prob_vec_ptr, v_block_cache_ptr, accums);
+
+ if (block_idx != block_num - 1) {
+ const int64_t next_physical_block_idx =
+ seq_block_table[block_idx + 1];
+ const scalar_t *__restrict__ next_v_block_cache_ptr =
+ v_cache + next_physical_block_idx * kv_block_stride +
+ kv_head_idx * kv_head_stride +
+ BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
+ vec_op::unroll_loop(
+ [&](int head_elem_idx) {
+ if (head_elem_idx % 2 == 0) {
+ vec_op::prefetch(next_v_block_cache_ptr +
+ BLOCK_SIZE * head_elem_idx);
+ }
+ });
+ }
+ }
+
+ vec_op::unroll_loop(
+ [&](int head_elem_idx) {
+ float value = accums[head_elem_idx].reduce_sum();
+ vec_op::storeFP32(value, out_ptr + head_elem_idx);
+ });
+ }
+ }
+ }
+ std::free(logits);
+ }
+};
+
+// Instantiates paged_attention_v1_impl for (T, HEAD_SIZE, BLOCK_SIZE) and
+// forwards the raw pointers/strides prepared by
+// paged_attention_v1_impl_launcher (all names are expected in the caller's
+// scope). The template argument list was restored here; without it the
+// macro parameters were never used.
+#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE)                 \
+  paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call(                   \
+      out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads,      \
+      scale, block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq,     \
+      alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
+      num_heads);
+
+// Extracts sizes, strides, and typed data pointers from the input tensors
+// and dispatches the v1 kernel on the runtime head_size.
+//
+// T: scalar element type of query/out/kv caches (selected by the
+//    VLLM_DISPATCH_FLOATING_TYPES caller).
+// BLOCK_SIZE: compile-time KV-cache block size (currently 16 only).
+template <typename T, int BLOCK_SIZE>
+void paged_attention_v1_impl_launcher(
+    torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache,
+    torch::Tensor &value_cache, int num_kv_heads, float scale,
+    torch::Tensor &block_tables, torch::Tensor &context_lens,
+    int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) {
+  int num_seqs = query.size(0);
+  int num_heads = query.size(1);
+  int head_size = query.size(2);
+  int max_num_blocks_per_seq = block_tables.size(1);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);
+
+  // NOTE: alibi_slopes is optional; a null pointer disables the ALiBi bias
+  // inside the kernel.
+  const float *alibi_slopes_ptr =
+      alibi_slopes
+          ? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
+          : nullptr;
+
+  T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
+  T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
+  T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
+  T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
+  int *block_tables_ptr = block_tables.data_ptr<int>();
+  int *context_lens_ptr = context_lens.data_ptr<int>();
+
+  // head_size is only known at runtime; map it to a compile-time HEAD_SIZE.
+  switch (head_size) {
+  case 64:
+    LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
+    break;
+  case 80:
+    LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
+    break;
+  case 96:
+    LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
+    break;
+  case 112:
+    LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
+    break;
+  case 128:
+    LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
+    break;
+  case 256:
+    LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
+    break;
+  default:
+    TORCH_CHECK(false, "Unsupported head size: ", head_size);
+    break;
+  }
+}
+
+// Binds the tensor arguments captured in the enclosing dispatch lambda to
+// the launcher instantiated for (T, BLOCK_SIZE). Template argument list
+// restored (it had been stripped).
+#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE)                               \
+  paged_attention_v1_impl_launcher<T, BLOCK_SIZE>(                           \
+      out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
+      context_lens, max_context_len, alibi_slopes);
+
+// Dispatches on the runtime block_size. Only 16 is supported by the CPU
+// kernels (the impl static_asserts BLOCK_SIZE == 16).
+#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T)                                \
+  switch (block_size) {                                                      \
+  case 16:                                                                   \
+    CALL_V1_KERNEL_LAUNCHER(T, 16);                                          \
+    break;                                                                   \
+  default:                                                                   \
+    TORCH_CHECK(false, "Unsupported block size: ", block_size);              \
+    break;                                                                   \
+  }
+} // namespace
+
+// Public entry point for paged attention v1 (single pass over the whole
+// context, no partition reduction). Dispatches on the query's floating
+// scalar type. kv_cache_dtype is accepted for API parity but not read here;
+// only kv_scale == 1.0 (i.e. no FP8 KV-cache scaling) is supported on CPU.
+void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
+                        torch::Tensor &key_cache, torch::Tensor &value_cache,
+                        int num_kv_heads, float scale,
+                        torch::Tensor &block_tables,
+                        torch::Tensor &context_lens, int block_size,
+                        int max_context_len,
+                        const c10::optional<torch::Tensor> &alibi_slopes,
+                        const std::string &kv_cache_dtype, float kv_scale) {
+  TORCH_CHECK(kv_scale == 1.0f);
+  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
+                               [&] {
+                                 CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
+                                 CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
+                                 CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl)
+                               });
+}
+
+// Paged attention v2
+namespace {
+// Paged attention v2: splits each sequence's context into fixed-size
+// partitions of PARTITION_SIZE tokens, computes softmax(QK^T)V per
+// partition in parallel, then rescales and reduces the per-partition
+// results (flash-decoding style). PARTITION_SIZE default restored as 512 —
+// TODO confirm against upstream; the static_asserts below only require it
+// to be a multiple of BLOCK_SIZE with 64-byte-aligned float storage.
+template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int PARTITION_SIZE = 512>
+struct paged_attention_v2_impl {
+  static void call(
+      scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
+      float
+          *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
+      float
+          *__restrict__ max_logits, // [num_seqs, num_heads,
+                                    // max_num_partitions]
+      scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads,
+                                      // max_num_partitions, head_size]
+      const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size]
+      const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
+                                            // head_size/x, block_size, x]
+      const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
+                                            // head_size, block_size]
+      const int num_kv_heads, const float scale,
+      const int
+          *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
+      const int *__restrict__ context_lens, // [num_seqs]
+      const int max_num_blocks_per_seq,
+      const float *__restrict__ alibi_slopes, // [num_heads]
+      const int q_stride, const int kv_block_stride, const int kv_head_stride,
+      const int num_seqs, const int num_heads, const int max_num_partitions) {
+    // x: scalar elements packed per 16 bytes (innermost key-cache dim).
+    constexpr int x = 16 / sizeof(scalar_t);
+    // GQA/MQA: several query heads share one KV head.
+    const int num_queries_per_kv = num_heads / num_kv_heads;
+
+    static_assert(BLOCK_SIZE == 16);
+    static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0);
+    static_assert(PARTITION_SIZE % BLOCK_SIZE == 0);
+
+    // Phase 1: attention over each (seq, partition, head) triple.
+#pragma omp parallel for collapse(3) schedule(static, 1)
+    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
+      for (int partition_idx = 0; partition_idx < max_num_partitions;
+           ++partition_idx) {
+        for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
+          const int context_len = context_lens[seq_idx];
+          const int start_token_idx = partition_idx * PARTITION_SIZE;
+
+          // Partitions entirely past this sequence's context are no-ops.
+          if (start_token_idx >= context_len)
+            continue;
+
+          const int partition_num =
+              (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
+          // With a single partition we can write straight to `out` and skip
+          // the cross-partition reduction phases.
+          const bool no_reduce = (partition_num == 1);
+          const int context_token_num =
+              (std::min(context_len, start_token_idx + PARTITION_SIZE) -
+               start_token_idx);
+          const int block_num =
+              (context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
+          const int last_block_token_num =
+              context_token_num - (block_num - 1) * BLOCK_SIZE;
+          const int *seq_block_table = block_tables +
+                                       max_num_blocks_per_seq * seq_idx +
+                                       start_token_idx / BLOCK_SIZE;
+          const int64_t kv_head_idx = head_idx / num_queries_per_kv;
+          const scalar_t *__restrict__ q_vec_ptr =
+              q + seq_idx * q_stride + head_idx * HEAD_SIZE;
+
+          // Cacheline-aligned on-stack scratch for this partition's logits.
+          float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};
+
+          // Compute logits: Q . K for every token block in the partition.
+          for (int block_idx = 0; block_idx < block_num; ++block_idx) {
+            const int64_t physical_block_idx = seq_block_table[block_idx];
+            const scalar_t *__restrict__ k_block_cache_ptr =
+                k_cache + physical_block_idx * kv_block_stride +
+                kv_head_idx * kv_head_stride;
+            float *__restrict__ head_block_logits =
+                logits + block_idx * BLOCK_SIZE;
+
+            reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
+                q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
+                block_idx == block_num - 1 ? last_block_token_num
+                                           : BLOCK_SIZE);
+          }
+
+          // Partition-local softmax; returns (max logit, sum of exps) so the
+          // partitions can be rescaled against each other later.
+          std::pair<float, float> max_and_sum;
+          if (alibi_slopes) {
+            max_and_sum = reduceSoftmaxAlibi(
+                logits, context_token_num, block_num * BLOCK_SIZE,
+                alibi_slopes[head_idx], start_token_idx, context_len);
+          } else {
+            max_and_sum = reduceSoftmax(logits, context_token_num,
+                                        block_num * BLOCK_SIZE);
+          }
+
+          auto &&[max_logit, exp_sum] = max_and_sum;
+
+          scalar_t *__restrict__ output_buffer = nullptr;
+          if (!no_reduce) {
+            // Stash per-partition stats and route output to tmp_out for the
+            // reduction phases below.
+            auto idx = seq_idx * num_heads * max_num_partitions +
+                       head_idx * max_num_partitions + partition_idx;
+            max_logits[idx] = max_logit;
+            exp_sums[idx] = exp_sum;
+            output_buffer =
+                tmp_out +
+                seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+                head_idx * max_num_partitions * HEAD_SIZE +
+                partition_idx * HEAD_SIZE;
+          } else {
+            output_buffer =
+                out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+          }
+
+          // Compute value: probabilities x V, 16 head elements at a time.
+          constexpr int head_elem_num_per_partition = 16;
+          constexpr int head_partition_num =
+              HEAD_SIZE / head_elem_num_per_partition;
+          for (int head_part_idx = 0; head_part_idx < head_partition_num;
+               ++head_part_idx) {
+            vec_op::FP32Vec16 accums[head_elem_num_per_partition];
+            scalar_t *__restrict__ out_ptr =
+                output_buffer + head_part_idx * head_elem_num_per_partition;
+            for (int block_idx = 0; block_idx < block_num; ++block_idx) {
+              const int64_t physical_block_idx = seq_block_table[block_idx];
+              const float *__restrict__ prob_vec_ptr =
+                  logits + block_idx * BLOCK_SIZE;
+              const scalar_t *__restrict__ v_block_cache_ptr =
+                  v_cache + physical_block_idx * kv_block_stride +
+                  kv_head_idx * kv_head_stride +
+                  BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
+              reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE>(
+                  prob_vec_ptr, v_block_cache_ptr, accums);
+
+              // Prefetch the next block's V rows (every other cacheline) to
+              // hide memory latency.
+              if (block_idx != block_num - 1) {
+                const int64_t next_physical_block_idx =
+                    seq_block_table[block_idx + 1];
+                const scalar_t *__restrict__ next_v_block_cache_ptr =
+                    v_cache + next_physical_block_idx * kv_block_stride +
+                    kv_head_idx * kv_head_stride +
+                    BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
+                vec_op::unroll_loop<int, head_elem_num_per_partition>(
+                    [&](int head_elem_idx) {
+                      if (head_elem_idx % 2 == 0) {
+                        vec_op::prefetch(next_v_block_cache_ptr +
+                                         BLOCK_SIZE * head_elem_idx);
+                      }
+                    });
+              }
+            }
+
+            vec_op::unroll_loop<int, head_elem_num_per_partition>(
+                [&](int head_elem_idx) {
+                  float value = accums[head_elem_idx].reduce_sum();
+                  vec_op::storeFP32(value, out_ptr + head_elem_idx);
+                });
+          }
+        }
+      }
+    }
+
+    // Phase 2: rescale partition softmax and store the factors to exp_sums.
+#pragma omp parallel for collapse(2) schedule(static, 1)
+    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
+      for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
+        const int context_len = context_lens[seq_idx];
+        const int partition_num =
+            (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
+
+        if (partition_num == 1)
+          continue;
+
+        reducePartitonSoftmax(
+            max_logits + seq_idx * num_heads * max_num_partitions +
+                head_idx * max_num_partitions,
+            exp_sums + seq_idx * num_heads * max_num_partitions +
+                head_idx * max_num_partitions,
+            partition_num);
+      }
+    }
+
+    // Phase 3: weighted sum of tmp_out across partitions using the rescale
+    // factors computed in phase 2.
+    using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
+    static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
+    constexpr int head_elem_num_per_group =
+        16; // Note: didn't align with the cacheline size, due to some
+            // HEAD_SIZE didn't align with 64 bytes
+    static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
+    constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
+    const float *__restrict__ rescale_factors = exp_sums;
+#pragma omp parallel for collapse(3) schedule(static, 1)
+    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
+      for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
+        for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
+          const int context_len = context_lens[seq_idx];
+          const int partition_num =
+              (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
+
+          // Single-partition results were already written to `out`.
+          if (partition_num == 1)
+            continue;
+
+          const float *__restrict__ seq_head_rescale_factors =
+              rescale_factors + seq_idx * num_heads * max_num_partitions +
+              head_idx * max_num_partitions;
+          const scalar_t *__restrict__ seq_head_tmp_out =
+              tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+              head_idx * max_num_partitions * HEAD_SIZE +
+              group_idx * head_elem_num_per_group;
+          scalar_t *__restrict__ seq_head_output =
+              out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
+              group_idx * head_elem_num_per_group;
+
+          vec_op::FP32Vec16 acc;
+          for (int i = 0; i < partition_num; ++i) {
+            vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]);
+            v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE);
+            vec_op::FP32Vec16 fp32_value(value);
+            acc = acc + fp32_value * rescale_factor;
+          }
+          v_load_vec_type cast_acc(acc);
+          cast_acc.save(seq_head_output);
+        }
+      }
+    }
+  }
+};
+
+// Instantiates paged_attention_v2_impl for (T, HEAD_SIZE, BLOCK_SIZE),
+// relying on the struct's default PARTITION_SIZE, and forwards the raw
+// pointers/strides prepared by paged_attention_v2_impl_launcher. Template
+// argument list restored (it had been stripped).
+#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE)                 \
+  paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE>::call(                   \
+      out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr,         \
+      key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
+      context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,  \
+      kv_block_stride, kv_head_stride, num_seqs, num_heads,                  \
+      max_num_partitions);
+
+// Extracts sizes, strides, and typed data pointers from the input tensors
+// and dispatches the v2 kernel on the runtime head_size. The number of
+// partitions is taken from exp_sums' last dimension, which the caller sized.
+//
+// T: scalar element type; BLOCK_SIZE: compile-time KV-cache block size.
+template <typename T, int BLOCK_SIZE>
+void paged_attention_v2_impl_launcher(
+    torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
+    torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
+    torch::Tensor &value_cache, int num_kv_heads, float scale,
+    torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size,
+    int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) {
+  int num_seqs = query.size(0);
+  int num_heads = query.size(1);
+  int head_size = query.size(2);
+  int max_num_blocks_per_seq = block_tables.size(1);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);
+  int max_num_partitions = exp_sums.size(-1);
+
+  // NOTE: alibi_slopes is optional; a null pointer disables the ALiBi bias
+  // inside the kernel.
+  const float *alibi_slopes_ptr =
+      alibi_slopes
+          ? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
+          : nullptr;
+
+  T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
+  float *exp_sums_ptr = reinterpret_cast<float *>(exp_sums.data_ptr());
+  float *max_logits_ptr = reinterpret_cast<float *>(max_logits.data_ptr());
+  T *tmp_out_ptr = reinterpret_cast<T *>(tmp_out.data_ptr());
+  T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
+  T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
+  T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
+  int *block_tables_ptr = block_tables.data_ptr<int>();
+  int *context_lens_ptr = context_lens.data_ptr<int>();
+
+  // head_size is only known at runtime; map it to a compile-time HEAD_SIZE.
+  switch (head_size) {
+  case 64:
+    LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
+    break;
+  case 80:
+    LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
+    break;
+  case 96:
+    LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
+    break;
+  case 112:
+    LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
+    break;
+  case 128:
+    LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
+    break;
+  case 256:
+    LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
+    break;
+  default:
+    TORCH_CHECK(false, "Unsupported head size: ", head_size);
+    break;
+  }
+}
+
+// Binds the tensor arguments captured in the enclosing dispatch lambda to
+// the v2 launcher instantiated for (T, BLOCK_SIZE). Template argument list
+// restored (it had been stripped).
+#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE)                               \
+  paged_attention_v2_impl_launcher<T, BLOCK_SIZE>(                           \
+      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,     \
+      num_kv_heads, scale, block_tables, context_lens, block_size,           \
+      max_context_len, alibi_slopes);
+
+// Dispatches on the runtime block_size. Only 16 is supported by the CPU
+// kernels (the impl static_asserts BLOCK_SIZE == 16).
+#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T)                                \
+  switch (block_size) {                                                      \
+  case 16:                                                                   \
+    CALL_V2_KERNEL_LAUNCHER(T, 16);                                          \
+    break;                                                                   \
+  default:                                                                   \
+    TORCH_CHECK(false, "Unsupported block size: ", block_size);              \
+    break;                                                                   \
+  }
+} // namespace
+
+// Public entry point for paged attention v2 (partitioned context with a
+// cross-partition reduction through exp_sums/max_logits/tmp_out). Dispatches
+// on the query's floating scalar type. kv_cache_dtype is accepted for API
+// parity but not read here; only kv_scale == 1.0 (no FP8 KV-cache scaling)
+// is supported on CPU.
+void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
+                        torch::Tensor &max_logits, torch::Tensor &tmp_out,
+                        torch::Tensor &query, torch::Tensor &key_cache,
+                        torch::Tensor &value_cache, int num_kv_heads,
+                        float scale, torch::Tensor &block_tables,
+                        torch::Tensor &context_lens, int block_size,
+                        int max_context_len,
+                        const c10::optional<torch::Tensor> &alibi_slopes,
+                        const std::string &kv_cache_dtype, float kv_scale) {
+  TORCH_CHECK(kv_scale == 1.0f);
+  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
+                               [&] {
+                                 CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
+                                 CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
+                                 CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl)
+                               });
+}
diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp
new file mode 100644
index 0000000000000..7849a5df991b1
--- /dev/null
+++ b/csrc/cpu/cache.cpp
@@ -0,0 +1,141 @@
+#include