diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh new file mode 100644 index 0000000000000..f187d1f181724 --- /dev/null +++ b/.buildkite/run-cpu-test.sh @@ -0,0 +1,14 @@ +# This script build the CPU docker image and run the offline inference inside the container. +# It serves a sanity check for compilation and basic model usage. +set -ex + +# Try building the docker image +docker build -t cpu-test -f Dockerfile.cpu . + +# Setup cleanup +remove_docker_container() { docker rm -f cpu-test || true; } +trap remove_docker_container EXIT +remove_docker_container + +# Run the image and launch offline inference +docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index ee384c27e1d0c..27e44463a30a6 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -34,7 +34,10 @@ steps: command: pytest -v -s engine tokenization test_sequence.py test_config.py - label: Entrypoints Test - command: pytest -v -s entrypoints + commands: + # these tests have to be separated, because each one will allocate all posible GPU memory + - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py + - pytest -v -s entrypoints/test_server_oot_registration.py - label: Examples Test working_dir: "/vllm-workspace/examples" @@ -90,7 +93,7 @@ steps: - bash run-benchmarks.sh - label: Documentation Build - working_dir: "/vllm-workspace/docs" + working_dir: "/vllm-workspace/test_docs/docs" no_gpu: True commands: - pip install -r requirements-docs.txt diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index 4dde733581822..3ed23c62c005d 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -8,6 +8,9 @@ steps: queue: amd command: bash .buildkite/run-amd-test.sh + - label: "CPU Test" + command: bash .buildkite/run-cpu-test.sh + - label: ":docker: build image" commands: - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ." diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 5211dc180798e..fc97e33c19af2 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -49,7 +49,7 @@ jobs: matrix: os: ['ubuntu-20.04'] python-version: ['3.8', '3.9', '3.10', '3.11'] - pytorch-version: ['2.1.2'] # Must be the most recent version that meets requirements.txt. + pytorch-version: ['2.2.1'] # Must be the most recent version that meets requirements-cuda.txt. cuda-version: ['11.8', '12.1'] steps: diff --git a/.github/workflows/scripts/build.sh b/.github/workflows/scripts/build.sh index 2578d448436d2..60a3978f9abd7 100644 --- a/.github/workflows/scripts/build.sh +++ b/.github/workflows/scripts/build.sh @@ -9,12 +9,13 @@ LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH # Install requirements $python_executable -m pip install wheel packaging -$python_executable -m pip install -r requirements.txt +$python_executable -m pip install -r requirements-cuda.txt # Limit the number of parallel jobs to avoid OOM export MAX_JOBS=1 # Make sure punica is built for the release (for LoRA) export VLLM_INSTALL_PUNICA_KERNELS=1 - +# Make sure release wheels are built for the following architectures +export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX" # Build $python_executable setup.py bdist_wheel --dist-dir=dist diff --git a/.gitignore b/.gitignore index b5195629e5cf3..b1513ef0ddb0c 100644 --- a/.gitignore +++ b/.gitignore @@ -181,6 +181,7 @@ _build/ # hip files generated by PyTorch *.hip *_hip* +hip_compat.h # Benchmark dataset *.json diff --git a/CMakeLists.txt b/CMakeLists.txt index 412b9c0cd59e0..1845151181284 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,10 @@ cmake_minimum_required(VERSION 3.21) project(vllm_extensions LANGUAGES CXX) +option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda") + message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") +message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) @@ -16,7 +19,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") # Supported AMD GPU architectures. -set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100") +set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100") # # Supported/expected torch versions for CUDA/ROCm. @@ -28,7 +31,7 @@ set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100") # requirements.txt files and should be kept consistent. The ROCm torch # versions are derived from Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.1.2") +set(TORCH_SUPPORTED_VERSION_CUDA "2.2.1") set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1") set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1") @@ -76,6 +79,19 @@ find_package(Torch REQUIRED) find_library(torch_python_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib") +# +# Forward the non-CUDA device extensions to external CMake scripts. +# +if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND + NOT VLLM_TARGET_DEVICE STREQUAL "rocm") + if (VLLM_TARGET_DEVICE STREQUAL "cpu") + include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake) + else() + message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}") + endif() + return() +endif() + # # Set up GPU language and check the torch version and warn if it isn't # what is expected. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8db5e569b6aec..81a8db2b268b0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -21,7 +21,6 @@ Express your support on Twitter if vLLM aids you, or simply offer your appreciat ### Build from source ```bash -pip install -r requirements.txt pip install -e . # This may take several minutes. ``` @@ -30,6 +29,8 @@ pip install -e . # This may take several minutes. ```bash pip install -r requirements-dev.txt +# linting and formatting +bash format.sh # Static type checking mypy # Unit tests diff --git a/Dockerfile b/Dockerfile index f975530e09407..d1d29177b0f44 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,6 +2,7 @@ # to run the OpenAI compatible server. #################### BASE BUILD IMAGE #################### +# prepare basic build environment FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev RUN apt-get update -y \ @@ -16,18 +17,26 @@ RUN ldconfig /usr/local/cuda-12.1/compat/ WORKDIR /workspace # install build and runtime dependencies -COPY requirements.txt requirements.txt +COPY requirements-common.txt requirements-common.txt +COPY requirements-cuda.txt requirements-cuda.txt RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements.txt + pip install -r requirements-cuda.txt # install development dependencies COPY requirements-dev.txt requirements-dev.txt RUN --mount=type=cache,target=/root/.cache/pip \ pip install -r requirements-dev.txt + +# cuda arch list used by torch +# can be useful for both `dev` and `test` +# explicitly set the list to avoid issues with torch 2.2 +# see https://github.com/pytorch/pytorch/pull/123243 +ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX' +ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} #################### BASE BUILD IMAGE #################### -#################### EXTENSION BUILD IMAGE #################### +#################### WHEEL BUILD IMAGE #################### FROM dev AS build # install build dependencies @@ -38,18 +47,16 @@ RUN --mount=type=cache,target=/root/.cache/pip \ # install compiler cache to speed up compilation leveraging local or remote caching RUN apt-get update -y && apt-get install -y ccache -# copy input files +# files and directories related to build wheels COPY csrc csrc COPY setup.py setup.py COPY cmake cmake COPY CMakeLists.txt CMakeLists.txt -COPY requirements.txt requirements.txt +COPY requirements-common.txt requirements-common.txt +COPY requirements-cuda.txt requirements-cuda.txt COPY pyproject.toml pyproject.toml -COPY vllm/__init__.py vllm/__init__.py +COPY vllm vllm -# cuda arch list used by torch -ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX' -ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} # max jobs used by Ninja to build extensions ARG max_jobs=2 ENV MAX_JOBS=${max_jobs} @@ -61,7 +68,15 @@ ENV VLLM_INSTALL_PUNICA_KERNELS=1 ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/ccache \ - python3 setup.py build_ext --inplace + --mount=type=cache,target=/root/.cache/pip \ + python3 setup.py bdist_wheel --dist-dir=dist + +# the `vllm_nccl` package must be installed from source distribution +# pip is too smart to store a wheel in the cache, and other CI jobs +# will directly use the wheel from the cache, which is not what we want. +# we need to remove it manually +RUN --mount=type=cache,target=/root/.cache/pip \ + pip cache remove vllm_nccl* #################### EXTENSION Build IMAGE #################### #################### FLASH_ATTENTION Build IMAGE #################### @@ -81,57 +96,59 @@ RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ #################### FLASH_ATTENTION Build IMAGE #################### +#################### vLLM installation IMAGE #################### +# image with vLLM installed +FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base +WORKDIR /vllm-workspace + +RUN apt-get update -y \ + && apt-get install -y python3-pip git vim + +# Workaround for https://github.com/openai/triton/issues/2507 and +# https://github.com/pytorch/pytorch/issues/107960 -- hopefully +# this won't be needed for future versions of this docker image +# or future versions of triton. +RUN ldconfig /usr/local/cuda-12.1/compat/ + +# install vllm wheel first, so that torch etc will be installed +RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ + --mount=type=cache,target=/root/.cache/pip \ + pip install dist/*.whl --verbose + +RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ + --mount=type=cache,target=/root/.cache/pip \ + pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir +#################### vLLM installation IMAGE #################### + + #################### TEST IMAGE #################### # image to run unit testing suite -FROM dev AS test +# note that this uses vllm installed by `pip` +FROM vllm-base AS test -# copy pytorch extensions separately to avoid having to rebuild -# when python code changes -WORKDIR /vllm-workspace -# ADD is used to preserve directory structure ADD . /vllm-workspace/ -COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/ -# Install flash attention (from pre-built wheel) -RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ - pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir -# ignore build dependencies installation because we are using pre-complied extensions -RUN rm pyproject.toml -RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . --verbose -#################### TEST IMAGE #################### - -#################### RUNTIME BASE IMAGE #################### -# We used base cuda image because pytorch installs its own cuda libraries. -# However pynccl depends on cuda libraries so we had to switch to the runtime image -# In the future it would be nice to get a container with pytorch and cuda without duplicating cuda -FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base - -# libnccl required for ray -RUN apt-get update -y \ - && apt-get install -y python3-pip - -WORKDIR /workspace -COPY requirements.txt requirements.txt +# install development dependencies (for testing) RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements.txt + pip install -r requirements-dev.txt -# Install flash attention (from pre-built wheel) -RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ - pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir - -#################### RUNTIME BASE IMAGE #################### +# doc requires source code +# we hide them inside `test_docs/` , so that this source code +# will not be imported by other tests +RUN mkdir test_docs +RUN mv docs test_docs/ +RUN mv vllm test_docs/ +#################### TEST IMAGE #################### #################### OPENAI API SERVER #################### # openai api server alternative FROM vllm-base AS vllm-openai + # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ pip install accelerate hf_transfer modelscope -COPY --from=build /workspace/vllm/*.so /workspace/vllm/ -COPY vllm vllm - ENV VLLM_USAGE_SOURCE production-docker-image ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/Dockerfile.cpu b/Dockerfile.cpu new file mode 100644 index 0000000000000..4251fddd6cc3b --- /dev/null +++ b/Dockerfile.cpu @@ -0,0 +1,20 @@ +# This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform. + +FROM ubuntu:22.04 + +RUN apt-get update -y \ + && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \ + && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 + +RUN pip install --upgrade pip \ + && pip install wheel packaging ninja setuptools>=49.4.0 numpy + +COPY ./ /workspace/vllm + +WORKDIR /workspace/vllm + +RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu + +RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install + +CMD ["/bin/bash"] diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 65a367994f960..10b8bf1e7fabd 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -23,6 +23,9 @@ RUN echo "FA_BRANCH is $FA_BRANCH" # In that case, we need to use the python reference attention implementation in vllm ARG BUILD_FA="1" +# whether to build triton on rocm +ARG BUILD_TRITON="1" + # Install some basic utilities RUN apt-get update && apt-get install python3 python3-pip -y @@ -75,6 +78,17 @@ RUN if [ "$BUILD_FA" = "1" ]; then \ RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \ rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi +# build triton +RUN if [ "$BUILD_TRITON" = "1" ]; then \ + mkdir -p libs \ + && cd libs \ + && pip uninstall -y triton \ + && git clone https://github.com/ROCm/triton.git \ + && cd triton/python \ + && pip3 install . \ + && cd ../..; \ + fi + COPY ./ /app/vllm RUN python3 -m pip install --upgrade pip diff --git a/MANIFEST.in b/MANIFEST.in index aa16da6500e6c..d385f194c6c0f 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,6 @@ include LICENSE -include requirements.txt +include requirements-common.txt +include requirements-cuda.txt include CMakeLists.txt recursive-include cmake * diff --git a/README.md b/README.md index 08e46b68cb7ce..d53227b82d87a 100644 --- a/README.md +++ b/README.md @@ -14,18 +14,8 @@ Easy, fast, and cheap LLM serving for everyone

---- - -**The Third vLLM Bay Area Meetup (April 2nd 6pm-8:30pm PT)** - -We are thrilled to announce our third vLLM Meetup! -The vLLM team will share recent updates and roadmap. -We will also have vLLM collaborators from Roblox coming up to the stage to discuss their experience in deploying LLMs with vLLM. -Please register [here](https://robloxandvllmmeetup2024.splashthat.com/) and join us! - ---- - *Latest News* 🔥 +- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing). - [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing). - [2024/01] Added ROCm 6.0 support to vLLM. - [2023/12] Added ROCm 5.7 support to vLLM. @@ -80,6 +70,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.) - Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.) - LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) +- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.) - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) - Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.) - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) @@ -88,7 +79,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.) - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) -- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.) +- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.) - Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.) - StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.) - Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 96a372e5511b7..ad428bd1c3644 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -334,7 +334,8 @@ async def async_request_openai_chat_completions( timestamp = time.perf_counter() data = json.loads(chunk) - if "content" in data["choices"][0]["delta"]: + delta = data["choices"][0]["delta"] + if delta.get("content", None): # First token if ttft == 0: ttft = time.perf_counter() - st @@ -345,8 +346,7 @@ async def async_request_openai_chat_completions( output.itl.append(timestamp - most_recent_timestamp) - generated_text += data["choices"][0]["delta"][ - "content"] + generated_text += delta["content"] most_recent_timestamp = timestamp diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index da02493b17fd3..91510dafc57a5 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -24,6 +24,7 @@ def main(args: argparse.Namespace): dtype=args.dtype, enforce_eager=args.enforce_eager, kv_cache_dtype=args.kv_cache_dtype, + quantization_param_path=args.quantization_param_path, device=args.device, ray_workers_use_nsight=args.ray_workers_use_nsight, enable_chunked_prefill=args.enable_chunked_prefill, @@ -67,7 +68,8 @@ def main(args: argparse.Namespace): return latency print("Warming up...") - run_to_completion(profile_dir=None) + for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): + run_to_completion(profile_dir=None) if args.profile: profile_dir = args.profile_result_dir @@ -83,7 +85,12 @@ def main(args: argparse.Namespace): latencies = [] for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): latencies.append(run_to_completion(profile_dir=None)) + latencies = np.array(latencies) + percentages = [10, 25, 50, 75, 90] + percentiles = np.percentile(latencies, percentages) print(f'Avg latency: {np.mean(latencies)} seconds') + for percentage, percentile in zip(percentages, percentiles): + print(f'{percentage}% percentile latency: {percentile} seconds') if __name__ == '__main__': @@ -105,9 +112,13 @@ if __name__ == '__main__': default=1, help='Number of generated sequences per prompt.') parser.add_argument('--use-beam-search', action='store_true') + parser.add_argument('--num-iters-warmup', + type=int, + default=10, + help='Number of iterations to run for warmup.') parser.add_argument('--num-iters', type=int, - default=3, + default=30, help='Number of iterations to run.') parser.add_argument('--trust-remote-code', action='store_true', @@ -127,10 +138,23 @@ if __name__ == '__main__': parser.add_argument( "--kv-cache-dtype", type=str, - choices=['auto', 'fp8_e5m2'], + choices=['auto', 'fp8'], default='auto', help= - 'Data type for kv cache storage. If "auto", will use model data type.') + 'Data type for kv cache storage. If "auto", will use model data type. ' + 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' + 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' + 'common inference criteria.') + parser.add_argument( + '--quantization-param-path', + type=str, + default=None, + help='Path to the JSON file containing the KV cache scaling factors. ' + 'This should generally be supplied, when KV cache dtype is FP8. ' + 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' + 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' + 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' + 'instead supported for common inference criteria.') parser.add_argument( '--profile', action='store_true', @@ -145,8 +169,8 @@ if __name__ == '__main__': "--device", type=str, default="cuda", - choices=["cuda"], - help='device type for vLLM execution, supporting CUDA only currently.') + choices=["cuda", "cpu"], + help='device type for vLLM execution, supporting CUDA and CPU.') parser.add_argument('--block-size', type=int, default=16, diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index bc7812ed4119e..6054df439fa57 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -110,7 +110,9 @@ def sample_sonnet_requests( prefix_len: int, tokenizer: PreTrainedTokenizerBase, ) -> List[Tuple[str, str, int, int]]: - assert input_len > prefix_len, "input_len must be greater than prefix_len." + assert ( + input_len > prefix_len + ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'." # Load the dataset. with open(dataset_path) as f: @@ -131,8 +133,9 @@ def sample_sonnet_requests( base_message, add_generation_prompt=True, tokenize=False) base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids) - assert (input_len > base_prompt_offset - ), f"Please set 'args.input-len' higher than {base_prompt_offset}." + assert ( + input_len > base_prompt_offset + ), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}." num_input_lines = round( (input_len - base_prompt_offset) / average_poem_len) @@ -140,7 +143,7 @@ def sample_sonnet_requests( # prompt are fixed poem lines. assert ( prefix_len > base_prompt_offset - ), f"Please set 'args.prefix-len' higher than {base_prompt_offset}." + ), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}." num_prefix_lines = round( (prefix_len - base_prompt_offset) / average_poem_len) @@ -373,9 +376,9 @@ def main(args: argparse.Namespace): input_requests = sample_sonnet_requests( dataset_path=args.dataset_path, num_requests=args.num_prompts, - input_len=args.input_len, - output_len=args.output_len, - prefix_len=args.prefix_len, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, ) input_requests = [(prompt, prompt_len, output_len) @@ -388,9 +391,9 @@ def main(args: argparse.Namespace): input_requests = sample_sonnet_requests( dataset_path=args.dataset_path, num_requests=args.num_prompts, - input_len=args.input_len, - output_len=args.output_len, - prefix_len=args.prefix_len, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, ) input_requests = [(prompt_formatted, prompt_len, output_len) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 9d84bde17d6d0..e71338273d1e5 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -29,22 +29,23 @@ def sample_requests( dataset = [(data["conversations"][0]["value"], data["conversations"][1]["value"]) for data in dataset] - # Tokenize the prompts and completions. - prompts = [prompt for prompt, _ in dataset] - prompt_token_ids = tokenizer(prompts).input_ids - completions = [completion for _, completion in dataset] - completion_token_ids = tokenizer(completions).input_ids - tokenized_dataset = [] - for i in range(len(dataset)): - output_len = len(completion_token_ids[i]) - if fixed_output_len is not None: - output_len = fixed_output_len - tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) + # Shuffle the dataset. + random.shuffle(dataset) - # Filter out too long sequences. + # Filter out sequences that are too long or too short filtered_dataset: List[Tuple[str, int, int]] = [] - for prompt, prompt_token_ids, output_len in tokenized_dataset: + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len if prompt_len < 4 or output_len < 4: # Prune too short sequences. continue @@ -53,9 +54,7 @@ def sample_requests( continue filtered_dataset.append((prompt, prompt_len, output_len)) - # Sample the requests. - sampled_requests = random.sample(filtered_dataset, num_requests) - return sampled_requests + return filtered_dataset def run_vllm( @@ -72,6 +71,7 @@ def run_vllm( max_model_len: Optional[int], enforce_eager: bool, kv_cache_dtype: str, + quantization_param_path: Optional[str], device: str, enable_prefix_caching: bool, gpu_memory_utilization: float = 0.9, @@ -89,6 +89,7 @@ def run_vllm( gpu_memory_utilization=gpu_memory_utilization, enforce_eager=enforce_eager, kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, device=device, enable_prefix_caching=enable_prefix_caching, download_dir=download_dir) @@ -217,7 +218,8 @@ def main(args: argparse.Namespace): args.seed, args.n, args.use_beam_search, args.trust_remote_code, args.dtype, args.max_model_len, args.enforce_eager, - args.kv_cache_dtype, args.device, + args.kv_cache_dtype, + args.quantization_param_path, args.device, args.enable_prefix_caching, args.gpu_memory_utilization, args.download_dir) elif args.backend == "hf": @@ -306,16 +308,29 @@ if __name__ == "__main__": parser.add_argument( "--kv-cache-dtype", type=str, - choices=["auto", "fp8_e5m2"], + choices=["auto", "fp8"], default="auto", help= - 'Data type for kv cache storage. If "auto", will use model data type.') + 'Data type for kv cache storage. If "auto", will use model data type. ' + 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' + 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' + 'common inference criteria.') + parser.add_argument( + '--quantization-param-path', + type=str, + default=None, + help='Path to the JSON file containing the KV cache scaling factors. ' + 'This should generally be supplied, when KV cache dtype is FP8. ' + 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' + 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' + 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' + 'instead supported for common inference criteria.') parser.add_argument( "--device", type=str, default="cuda", - choices=["cuda"], - help='device type for vLLM execution, supporting CUDA only currently.') + choices=["cuda", "cpu"], + help='device type for vLLM execution, supporting CUDA and CPU.') parser.add_argument( "--enable-prefix-caching", action='store_true', diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index f6c8f900a3bff..f71d1fcaaef50 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -97,6 +97,9 @@ def main( torch.cuda.cudart().cudaProfilerStart() start_time = time.perf_counter() + # Using default kv_scale + kv_scale = 1.0 + for _ in range(num_iters): if version == "v1": ops.paged_attention_v1( @@ -112,6 +115,7 @@ def main( max_context_len, alibi_slopes, kv_cache_dtype, + kv_scale, ) elif version == "v2": ops.paged_attention_v2( @@ -130,6 +134,7 @@ def main( max_context_len, alibi_slopes, kv_cache_dtype, + kv_scale, ) else: raise ValueError(f"Invalid version: {version}") @@ -179,11 +184,13 @@ if __name__ == '__main__': parser.add_argument( "--kv-cache-dtype", type=str, - choices=["auto", "fp8_e5m2"], + choices=["auto", "fp8"], default="auto", help= - 'Data type for kv cache storage. If "auto", will use model data type.') - parser.add_argument("--device", type=str, choices=["cuda"], default="cuda") + 'Data type for kv cache storage. If "auto", will use model data type. ' + 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' + 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' + 'common inference criteria.') args = parser.parse_args() print(args) diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake new file mode 100644 index 0000000000000..0cf37769a6960 --- /dev/null +++ b/cmake/cpu_extension.cmake @@ -0,0 +1,90 @@ +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# +# Define environment variables for special configurations +# +if(DEFINED ENV{VLLM_CPU_AVX512BF16}) + set(ENABLE_AVX512BF16 ON) +endif() + +include_directories("${CMAKE_SOURCE_DIR}/csrc") + +# +# Check the compile flags +# +list(APPEND CXX_COMPILE_FLAGS + "-fopenmp" + "-DVLLM_CPU_EXTENSION") + +execute_process(COMMAND cat /proc/cpuinfo + RESULT_VARIABLE CPUINFO_RET + OUTPUT_VARIABLE CPUINFO) + +if (NOT CPUINFO_RET EQUAL 0) + message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo") +endif() + +function (find_isa CPUINFO TARGET OUT) + string(FIND ${CPUINFO} ${TARGET} ISA_FOUND) + if(NOT ISA_FOUND EQUAL -1) + set(${OUT} ON PARENT_SCOPE) + else() + set(${OUT} OFF PARENT_SCOPE) + endif() +endfunction() + +find_isa(${CPUINFO} "avx512f" AVX512_FOUND) + +if (AVX512_FOUND) + list(APPEND CXX_COMPILE_FLAGS + "-mavx512f" + "-mavx512vl" + "-mavx512bw" + "-mavx512dq") + + find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND) + if (AVX512BF16_FOUND OR ENABLE_AVX512BF16) + if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) + list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") + else() + message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") + endif() + else() + message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.") + endif() +else() + message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.") +endif() + +message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") + + +# +# Define extension targets +# + +# +# _C extension +# +set(VLLM_EXT_SRC + "csrc/cpu/activation.cpp" + "csrc/cpu/attention.cpp" + "csrc/cpu/cache.cpp" + "csrc/cpu/layernorm.cpp" + "csrc/cpu/pos_encoding.cpp" + "csrc/cpu/pybind.cpp") + +define_gpu_extension_target( + _C + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_EXT_SRC} + COMPILE_FLAGS ${CXX_COMPILE_FLAGS} + WITH_SOABI +) + +add_custom_target(default) +message(STATUS "Enabling C extension.") +add_dependencies(default _C) + diff --git a/cmake/utils.cmake b/cmake/utils.cmake index c7d3d85389838..7c71673e36f29 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -100,6 +100,8 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2") + endif() + if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) list(REMOVE_ITEM GPU_FLAGS "-D__CUDA_NO_HALF_OPERATORS__" "-D__CUDA_NO_HALF_CONVERSIONS__" @@ -117,6 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) list(APPEND GPU_FLAGS "-DUSE_ROCM" + "-DENABLE_FP8_E4M3" "-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_OPERATORS__" "-fno-gpu-rdc") diff --git a/csrc/attention/attention_dtypes.h b/csrc/attention/attention_dtypes.h index 61748e6b1eee6..64f86381d9db9 100644 --- a/csrc/attention/attention_dtypes.h +++ b/csrc/attention/attention_dtypes.h @@ -4,4 +4,4 @@ #include "dtype_float16.cuh" #include "dtype_float32.cuh" #include "dtype_bfloat16.cuh" -#include "dtype_fp8_e5m2.cuh" +#include "dtype_fp8.cuh" diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 5e61668d5cc1a..f3a5bbfd3098d 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -22,12 +22,26 @@ #include "attention_dtypes.h" #include "attention_utils.cuh" -#ifdef ENABLE_FP8_E5M2 + +#if defined(ENABLE_FP8_E5M2) #include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh" +#elif defined(ENABLE_FP8_E4M3) +#include "../quantization/fp8/amd_detail/quant_utils.cuh" #endif #include +#ifdef USE_ROCM + #include + typedef __hip_bfloat16 __nv_bfloat16; +#endif + +#ifndef USE_ROCM +#define WARP_SIZE 32 +#else +#define WARP_SIZE warpSize +#endif + #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) @@ -78,7 +92,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_E5M2_KV_CACHE, + bool IS_FP8_KV_CACHE, int PARTITION_SIZE = 0> // Zero means no partitioning. __device__ void paged_attention_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -95,7 +109,8 @@ __device__ void paged_attention_kernel( const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, - const int kv_head_stride) { + const int kv_head_stride, + const float kv_scale) { const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; @@ -142,7 +157,7 @@ __device__ void paged_attention_kernel( constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); using K_vec = typename Vec::Type; using Q_vec = typename Vec::Type; -#ifdef ENABLE_FP8_E5M2 +#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using Quant_vec = typename Vec::Type; #endif @@ -208,11 +223,16 @@ __device__ void paged_attention_kernel( const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; - if constexpr (IS_FP8_E5M2_KV_CACHE) { -#ifdef ENABLE_FP8_E5M2 + if constexpr (IS_FP8_KV_CACHE) { +#if defined(ENABLE_FP8_E5M2) Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); // Vector conversion from Quant_vec to K_vec. k_vecs[j] = fp8_e5m2_unscaled::vec_conversion(k_vec_quant); +#elif defined(ENABLE_FP8_E4M3) + Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + // Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k + // cache vec to k vec in higher precision (FP16, BFloat16, etc.) + k_vecs[j] = fp8_e4m3::scaled_vec_conversion(k_vec_quant, kv_scale); #else assert(false); #endif @@ -292,7 +312,7 @@ __device__ void paged_attention_kernel( constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; using L_vec = typename Vec::Type; -#ifdef ENABLE_FP8_E5M2 +#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using V_quant_vec = typename Vec::Type; #endif using Float_L_vec = typename FloatVec::Type; @@ -328,11 +348,16 @@ __device__ void paged_attention_kernel( if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec; - if constexpr (IS_FP8_E5M2_KV_CACHE) { -#ifdef ENABLE_FP8_E5M2 + if constexpr (IS_FP8_KV_CACHE) { +#if defined(ENABLE_FP8_E5M2) V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. v_vec = fp8_e5m2_unscaled::vec_conversion(v_quant_vec); +#elif defined(ENABLE_FP8_E4M3) + V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); + // Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert + // FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.) + v_vec = fp8_e4m3::scaled_vec_conversion(v_quant_vec, kv_scale); #else assert(false); #endif @@ -423,7 +448,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_E5M2_KV_CACHE> + bool IS_FP8_KV_CACHE> __global__ void paged_attention_v1_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -437,11 +462,12 @@ __global__ void paged_attention_v1_kernel( const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, - const int kv_head_stride) { - paged_attention_kernel( + const int kv_head_stride, + const float kv_scale) { + paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); } // Grid: (num_heads, num_seqs, max_num_partitions). @@ -451,7 +477,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_E5M2_KV_CACHE, + bool IS_FP8_KV_CACHE, int PARTITION_SIZE> __global__ void paged_attention_v2_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -468,11 +494,12 @@ __global__ void paged_attention_v2_kernel( const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, - const int kv_head_stride) { - paged_attention_kernel( + const int kv_head_stride, + const float kv_scale) { + paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, - q_stride, kv_block_stride, kv_head_stride); + q_stride, kv_block_stride, kv_head_stride, kv_scale); } // Grid: (num_heads, num_seqs). @@ -579,9 +606,9 @@ __global__ void paged_attention_v2_reduce_kernel( #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ ((void*)vllm::paged_attention_v1_kernel), shared_mem_size); \ + IS_FP8_KV_CACHE>), shared_mem_size); \ vllm::paged_attention_v1_kernel<<>>( \ + IS_FP8_KV_CACHE><<>>( \ out_ptr, \ query_ptr, \ key_cache_ptr, \ @@ -594,14 +621,15 @@ __global__ void paged_attention_v2_reduce_kernel( alibi_slopes_ptr, \ q_stride, \ kv_block_stride, \ - kv_head_stride); + kv_head_stride, \ + kv_scale); // TODO(woosuk): Tune NUM_THREADS. template< typename T, typename CACHE_T, int BLOCK_SIZE, - bool IS_FP8_E5M2_KV_CACHE, + bool IS_FP8_KV_CACHE, int NUM_THREADS = 128> void paged_attention_v1_launcher( torch::Tensor& out, @@ -613,7 +641,8 @@ void paged_attention_v1_launcher( torch::Tensor& block_tables, torch::Tensor& context_lens, int max_context_len, - const c10::optional& alibi_slopes) { + const c10::optional& alibi_slopes, + float kv_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -677,8 +706,8 @@ void paged_attention_v1_launcher( } } -#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \ - paged_attention_v1_launcher( \ +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + paged_attention_v1_launcher( \ out, \ query, \ key_cache, \ @@ -688,20 +717,21 @@ void paged_attention_v1_launcher( block_tables, \ context_lens, \ max_context_len, \ - alibi_slopes); + alibi_slopes, \ + kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \ +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ switch (block_size) { \ case 8: \ - CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \ + CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ break; \ case 16: \ - CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \ + CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ break; \ case 32: \ - CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \ + CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ @@ -720,7 +750,8 @@ void paged_attention_v1( int block_size, int max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype) { + const std::string& kv_cache_dtype, + float kv_scale) { if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Float) { CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false); @@ -731,7 +762,7 @@ void paged_attention_v1( } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } - } else if (kv_cache_dtype == "fp8_e5m2") { + } else if (kv_cache_dtype == "fp8") { if (query.dtype() == at::ScalarType::Float) { CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); } else if (query.dtype() == at::ScalarType::Half) { @@ -748,7 +779,7 @@ void paged_attention_v1( #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ vllm::paged_attention_v2_kernel \ + IS_FP8_KV_CACHE, PARTITION_SIZE> \ <<>>( \ exp_sums_ptr, \ max_logits_ptr, \ @@ -764,7 +795,8 @@ void paged_attention_v1( alibi_slopes_ptr, \ q_stride, \ kv_block_stride, \ - kv_head_stride); \ + kv_head_stride, \ + kv_scale); \ vllm::paged_attention_v2_reduce_kernel \ <<>>( \ out_ptr, \ @@ -778,7 +810,7 @@ template< typename T, typename CACHE_T, int BLOCK_SIZE, - bool IS_FP8_E5M2_KV_CACHE, + bool IS_FP8_KV_CACHE, int NUM_THREADS = 128, int PARTITION_SIZE = 512> void paged_attention_v2_launcher( @@ -794,7 +826,8 @@ void paged_attention_v2_launcher( torch::Tensor& block_tables, torch::Tensor& context_lens, int max_context_len, - const c10::optional& alibi_slopes) { + const c10::optional& alibi_slopes, + float kv_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -864,8 +897,8 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \ - paged_attention_v2_launcher( \ +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + paged_attention_v2_launcher( \ out, \ exp_sums, \ max_logits, \ @@ -878,20 +911,21 @@ void paged_attention_v2_launcher( block_tables, \ context_lens, \ max_context_len, \ - alibi_slopes); + alibi_slopes, \ + kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ switch (block_size) { \ case 8: \ - CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \ + CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ break; \ case 16: \ - CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \ + CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ break; \ case 32: \ - CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \ + CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ @@ -913,7 +947,8 @@ void paged_attention_v2( int block_size, int max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype) { + const std::string& kv_cache_dtype, + float kv_scale) { if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Float) { CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false); @@ -924,7 +959,7 @@ void paged_attention_v2( } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } - } else if (kv_cache_dtype == "fp8_e5m2") { + } else if (kv_cache_dtype == "fp8") { if (query.dtype() == at::ScalarType::Float) { CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); } else if (query.dtype() == at::ScalarType::Half) { diff --git a/csrc/attention/dtype_fp8_e5m2.cuh b/csrc/attention/dtype_fp8.cuh similarity index 89% rename from csrc/attention/dtype_fp8_e5m2.cuh rename to csrc/attention/dtype_fp8.cuh index 0580fbb8e863f..d11dee91ebe87 100644 --- a/csrc/attention/dtype_fp8_e5m2.cuh +++ b/csrc/attention/dtype_fp8.cuh @@ -8,7 +8,7 @@ #endif namespace vllm { -#ifdef ENABLE_FP8_E5M2 +#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) // fp8 vector types for quantization of kv cache template<> diff --git a/csrc/cache.h b/csrc/cache.h index 765e231abd26f..718a5f6cfd7f7 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -21,9 +21,10 @@ void reshape_and_cache( torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype); + const std::string& kv_cache_dtype, + const float kv_scale); // Just for unittest -void convert_fp8_e5m2( +void convert_fp8( torch::Tensor& src_cache, torch::Tensor& dst_cache); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 7254010b8e3a9..24aaa2ff3e263 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -4,8 +4,10 @@ #include "cuda_compat.h" #include "dispatch_utils.h" -#ifdef ENABLE_FP8_E5M2 +#if defined(ENABLE_FP8_E5M2) #include "quantization/fp8_e5m2_kvcache/quant_utils.cuh" +#elif defined(ENABLE_FP8_E4M3) +#include "quantization/fp8/amd_detail/quant_utils.cuh" #endif #include @@ -151,7 +153,7 @@ void copy_blocks( namespace vllm { -template +template __global__ void reshape_and_cache_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] @@ -163,7 +165,8 @@ __global__ void reshape_and_cache_kernel( const int num_heads, const int head_size, const int block_size, - const int x) { + const int x, + const float kv_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -195,10 +198,13 @@ __global__ void reshape_and_cache_kernel( + block_offset; scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_value = value[src_value_idx]; - if constexpr (is_fp8_e5m2_kv_cache) { -#ifdef ENABLE_FP8_E5M2 + if constexpr (is_fp8_kv_cache) { +#if defined(ENABLE_FP8_E5M2) key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_key); value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_value); +#elif defined(ENABLE_FP8_E4M3) + key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion(tgt_key, kv_scale); + value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion(tgt_value, kv_scale); #else assert(false); #endif @@ -211,8 +217,8 @@ __global__ void reshape_and_cache_kernel( } // namespace vllm -#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \ - vllm::reshape_and_cache_kernel<<>>( \ +#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ + vllm::reshape_and_cache_kernel<<>>( \ reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast(key_cache.data_ptr()), \ @@ -223,7 +229,8 @@ __global__ void reshape_and_cache_kernel( num_heads, \ head_size, \ block_size, \ - x); + x, \ + kv_scale); void reshape_and_cache( torch::Tensor& key, // [num_tokens, num_heads, head_size] @@ -231,7 +238,8 @@ void reshape_and_cache( torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype) + const std::string& kv_cache_dtype, + const float kv_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); @@ -254,7 +262,7 @@ void reshape_and_cache( } else if (key.dtype() == at::ScalarType::BFloat16) { CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false); } - } else if (kv_cache_dtype == "fp8_e5m2") { + } else if (kv_cache_dtype == "fp8") { if (key.dtype() == at::ScalarType::Float) { CALL_RESHAPE_AND_CACHE(float, uint8_t, true); } else if (key.dtype() == at::ScalarType::Half) { @@ -270,15 +278,17 @@ void reshape_and_cache( namespace vllm { template -__global__ void convert_fp8_e5m2_kernel( +__global__ void convert_fp8_kernel( const Tin* __restrict__ src_cache, Tout* __restrict__ dst_cache, const int64_t block_stride) { const int64_t block_idx = blockIdx.x; for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { int64_t idx = block_idx * block_stride + i; -#ifdef ENABLE_FP8_E5M2 +#if defined(ENABLE_FP8_E5M2) dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion(src_cache[idx]); +#elif defined(ENABLE_FP8_E4M3) + dst_cache[idx] = fp8_e4m3::vec_conversion(src_cache[idx]); #else assert(false); #endif @@ -287,16 +297,25 @@ __global__ void convert_fp8_e5m2_kernel( } // namespace vllm -#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \ - vllm::convert_fp8_e5m2_kernel<<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst_cache.data_ptr()), \ +#define CALL_CONVERT_FP8(Tout, Tin) \ + vllm::convert_fp8_kernel<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst_cache.data_ptr()), \ block_stride); -void convert_fp8_e5m2( +void convert_fp8( torch::Tensor& src_cache, torch::Tensor& dst_cache) { + torch::Device src_device = src_cache.device(); + torch::Device dst_device = dst_cache.device(); + TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") + TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") + TORCH_CHECK( + src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); + at::cuda::OptionalCUDAGuard device_guard(src_device); + int64_t num_blocks = src_cache.size(0); int64_t block_stride = src_cache.stride(0); @@ -305,16 +324,16 @@ void convert_fp8_e5m2( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (src_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8_E5M2(uint8_t, float); + CALL_CONVERT_FP8(uint8_t, float); } else if (src_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t); + CALL_CONVERT_FP8(uint8_t, uint16_t); } else if (src_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16); + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16); } else if (dst_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8_E5M2(float, uint8_t); + CALL_CONVERT_FP8(float, uint8_t); } else if (dst_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t); + CALL_CONVERT_FP8(uint16_t, uint8_t); } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t); + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t); } } diff --git a/csrc/cpu/activation.cpp b/csrc/cpu/activation.cpp new file mode 100644 index 0000000000000..1bd24eb79d129 --- /dev/null +++ b/csrc/cpu/activation.cpp @@ -0,0 +1,148 @@ +#include "cpu_types.hpp" + +namespace { +template +void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, + scalar_t *__restrict__ output) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + + TORCH_CHECK(d % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + for (int j = 0; j < d; j += VEC_ELEM_NUM) { + int start = i * d; + if constexpr (is_gated) { + start *= 2; + } + + const scalar_vec_t x(input + start + j); + const vec_op::FP32Vec8 f32_x(x); + vec_op::FP32Vec8 f32_ans = func(f32_x); + + if constexpr (is_gated) { + const scalar_vec_t y(input + start + d + j); + const vec_op::FP32Vec8 f32_y(y); + f32_ans = f32_y * f32_ans; + } + + const scalar_vec_t result(f32_ans); + result.save(output + i * d + j); + } + } +} + +FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) { + const vec_op::FP32Vec8 zeros(0.0); + const vec_op::FP32Vec8 ones(1.0); + return x / (ones + (zeros - x).exp()); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(0.79788456f); + const vec_op::FP32Vec8 w2(0.044715f); + const vec_op::FP32Vec8 w3(0.5); + const vec_op::FP32Vec8 x3 = x * x * x; + const vec_op::FP32Vec8 t = (w1 * (x + w2 * x3)).tanh(); + return w3 * x * (ones + t); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(0.79788456f); + const vec_op::FP32Vec8 w2(0.044715f); + const vec_op::FP32Vec8 w3(0.5); + const vec_op::FP32Vec8 t = (x * w1 * (ones + x * w2 * x)).tanh(); + return w3 * x * (ones + t); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(M_SQRT1_2); + const vec_op::FP32Vec8 w2(0.5); + return x * w2 * (ones + (x * w1).er()); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); + const vec_op::FP32Vec8 w2(0.5); + const vec_op::FP32Vec8 w3(0.044715); + const vec_op::FP32Vec8 x_3 = x * x * x; + const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); + return x * w2 * (ones + inner.tanh()); +} +}; // namespace + +void silu_and_mul(torch::Tensor &out, torch::Tensor &input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "silu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(silu_and_mul_impl) + activation_kernel(num_tokens, d, + input.data_ptr(), + out.data_ptr()); + CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) + }); +} + +void gelu_and_mul(torch::Tensor &out, // [..., d] + torch::Tensor &input) // [..., 2 * d] +{ + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "gelu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) + activation_kernel(num_tokens, d, + input.data_ptr(), + out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) + }); +} + +void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] + torch::Tensor &input) // [..., 2 * d] +{ + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "gelu_tanh_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), + out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl) + }); +} + +void gelu_new(torch::Tensor &out, torch::Tensor &input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1); + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_new_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_new_impl) + }); +} + +void gelu_fast(torch::Tensor &out, torch::Tensor &input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1); + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_fast_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_fast_impl) + }); +} diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp new file mode 100644 index 0000000000000..365bbd5e23728 --- /dev/null +++ b/csrc/cpu/attention.cpp @@ -0,0 +1,746 @@ +#include "cpu_types.hpp" + +namespace { + +template struct KernelVecType { + using q_load_vec_type = void; + using q_vec_type = void; + using k_load_vec_type = void; + using k_vec_type = void; + using qk_acc_vec_type = void; + using v_load_vec_type = void; +}; + +template <> struct KernelVecType { + using q_load_vec_type = vec_op::FP32Vec4; + using q_vec_type = vec_op::FP32Vec16; + using k_load_vec_type = vec_op::FP32Vec16; + using k_vec_type = vec_op::FP32Vec16; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::FP32Vec16; +}; + +#ifdef __AVX512BF16__ +template <> struct KernelVecType { + using q_load_vec_type = vec_op::BF16Vec8; + using q_vec_type = vec_op::BF16Vec32; + using k_load_vec_type = vec_op::BF16Vec32; + using k_vec_type = vec_op::BF16Vec32; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::BF16Vec16; +}; +#else +template <> struct KernelVecType { + using q_load_vec_type = vec_op::BF16Vec8; + using q_vec_type = vec_op::FP32Vec16; + using k_load_vec_type = vec_op::BF16Vec16; + using k_vec_type = vec_op::FP32Vec16; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::BF16Vec16; +}; +#endif + +template +FORCE_INLINE std::pair reduceSoftmax(T *data, const int size, + const int capacity) { + T max = data[0]; + for (int i = 1; i < size; ++i) { + max = max >= data[i] ? max : data[i]; + } + + T sum = 0; + for (int i = 0; i < size; ++i) { + data[i] = std::exp(data[i] - max); + sum += data[i]; + } + + int i = 0; + for (; i < size; ++i) { + data[i] /= sum; + } + + for (; i < capacity; ++i) { + data[i] = 0; + } + + return {max, sum}; +} + +template +FORCE_INLINE std::pair +reduceSoftmaxAlibi(T *data, const int size, const int capacity, + const float alibi_slope, const int start_index, + const int context_len) { + data[0] += alibi_slope * (start_index - context_len + 1); + T max = data[0]; + for (int i = 1; i < size; ++i) { + T qk = data[i] + alibi_slope * (start_index + i - context_len + 1); + data[i] = qk; + max = max >= qk ? max : qk; + } + + T sum = 0; + for (int i = 0; i < size; ++i) { + data[i] = std::exp(data[i] - max); + sum += data[i]; + } + + int i = 0; + for (; i < size; ++i) { + data[i] /= sum; + } + + for (; i < capacity; ++i) { + data[i] = 0; + } + + return {max, sum}; +} + +template +FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data, + const int size) { + T max = max_data[0]; + for (int i = 1; i < size; ++i) { + max = max >= max_data[i] ? max : max_data[i]; + } + + T rescaled_sum = 0; + for (int i = 0; i < size; ++i) { + T rescale_factor = std::exp(max_data[i] - max); + rescaled_sum += rescale_factor * sum_data[i]; + sum_data[i] *= rescale_factor; + } + for (int i = 0; i < size; ++i) { + sum_data[i] /= rescaled_sum + 1e-8; + } +} + +template +struct reduceQKBlockKernel { + using q_load_vec_type = typename KernelVecType::q_load_vec_type; + using q_vec_type = typename KernelVecType::q_vec_type; + using k_load_vec_type = typename KernelVecType::k_load_vec_type; + using k_vec_type = typename KernelVecType::k_vec_type; + using qk_acc_vec_type = typename KernelVecType::qk_acc_vec_type; + + constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x; + constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP; + constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4; + + static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4); + static_assert(k_load_vec_type::get_elem_num() % x == 0); + static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); + + FORCE_INLINE static void call(const scalar_t *__restrict__ q, + const scalar_t *__restrict__ k_block, + float *__restrict__ logits, float scale, + const int token_num) { + const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; + + qk_acc_vec_type group_accums[MAX_GROUP_NUM]; + if (token_num == BLOCK_SIZE) { + for (int q_offset = 0; q_offset < HEAD_SIZE; + q_offset += x, k_block += x * BLOCK_SIZE) { + q_load_vec_type q_load_group_vec(q + q_offset); + q_vec_type q_group_vec(q_load_group_vec); + + vec_op::unroll_loop( + [k_block, &q_group_vec, &group_accums](int token_group_idx) { + k_load_vec_type k_load_group_vec(k_block + token_group_idx * x * + TOKEN_PER_GROUP); + k_vec_type k_group_vec(k_load_group_vec); + vec_op::fma(group_accums[token_group_idx], q_group_vec, + k_group_vec); + vec_op::prefetch(k_block + x * BLOCK_SIZE + + token_group_idx * x * TOKEN_PER_GROUP); + }); + } + } else { + for (int q_offset = 0; q_offset < HEAD_SIZE; + q_offset += x, k_block += x * BLOCK_SIZE) { + q_load_vec_type q_load_group_vec(q + q_offset); + q_vec_type q_group_vec(q_load_group_vec); + for (int token_group_start = 0; token_group_start < group_num; + token_group_start += UNROLL_GROUP_NUM) { + vec_op::unroll_loop( + [token_group_start, k_block, &q_group_vec, + &group_accums](int token_group_idx) { + token_group_idx += token_group_start; + k_load_vec_type k_load_group_vec(k_block + token_group_idx * x * + TOKEN_PER_GROUP); + k_vec_type k_group_vec(k_load_group_vec); + vec_op::fma(group_accums[token_group_idx], q_group_vec, + k_group_vec); + vec_op::prefetch(k_block + x * BLOCK_SIZE + + token_group_idx * x * TOKEN_PER_GROUP); + }); + } + } + } + + for (int token_group_idx = 0; token_group_idx < group_num; + ++token_group_idx) { + vec_op::unroll_loop( + [&group_accums, logits, scale, token_group_idx](int token_idx) { + float dot_v = + group_accums[token_group_idx] + .template reduce_sub_sum(token_idx); + logits[token_group_idx * TOKEN_PER_GROUP + token_idx] = + dot_v * scale; + }); + } + } +}; + +template +FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, + acc_t &&acc) { + using v_load_vec_type = typename KernelVecType::v_load_vec_type; + constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); + static_assert(BLOCK_SIZE == ELEM_NUM); + vec_op::FP32Vec16 prob_vec(prob); + + vec_op::unroll_loop([&](int head_elem_idx) { + v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx); + vec_op::FP32Vec16 fp32_v_vec(v_vec); + acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; + }); +} +}; // namespace + +// Paged attention v1 +namespace { +template +struct paged_attention_v1_impl { + static void + call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int + *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int *__restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float *__restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const int num_seqs, const int num_heads) { + constexpr int x = 16 / sizeof(scalar_t); + const int num_queries_per_kv = num_heads / num_kv_heads; + + static_assert(BLOCK_SIZE == 16); + + int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE; + int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0; + TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0); + + const int parallel_work_item_num = omp_get_max_threads(); + + size_t logits_bytes = + parallel_work_item_num * max_context_len_padded * sizeof(float); + float *logits = (float *)std::aligned_alloc( + 64, logits_bytes); // Cacheline alignment for each context token. + // [parallel_work_item_num, max_context_len_padded] + +#pragma omp parallel for collapse(2) schedule(dynamic, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + int context_len = context_lens[seq_idx]; + const int *seq_block_table = + block_tables + max_num_blocks_per_seq * seq_idx; + const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int64_t kv_head_idx = head_idx / num_queries_per_kv; + const scalar_t *__restrict__ q_vec_ptr = + q + seq_idx * q_stride + head_idx * HEAD_SIZE; + const int last_block_token_num = + context_len - (block_num - 1) * BLOCK_SIZE; + float *__restrict__ thread_block_logits = + logits + omp_get_thread_num() * max_context_len_padded; + + // Compute logits + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const scalar_t *__restrict__ k_block_cache_ptr = + k_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride; + float *__restrict__ head_block_logits = + thread_block_logits + block_idx * BLOCK_SIZE; + + reduceQKBlockKernel::call( + q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, + block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE); + } + + // Compute softmax + if (alibi_slopes) { + reduceSoftmaxAlibi(thread_block_logits, context_len, + block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, + context_len); + } else { + reduceSoftmax(thread_block_logits, context_len, + block_num * BLOCK_SIZE); + } + + // Compute value + constexpr int head_elem_num_per_partition = 16; + constexpr int head_partition_num = + HEAD_SIZE / head_elem_num_per_partition; + for (int head_part_idx = 0; head_part_idx < head_partition_num; + ++head_part_idx) { + vec_op::FP32Vec16 accums[head_elem_num_per_partition]; + scalar_t *__restrict__ out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + + head_part_idx * head_elem_num_per_partition; + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const float *__restrict__ prob_vec_ptr = + thread_block_logits + block_idx * BLOCK_SIZE; + const scalar_t *__restrict__ v_block_cache_ptr = + v_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; + reduceValueBlock( + prob_vec_ptr, v_block_cache_ptr, accums); + + if (block_idx != block_num - 1) { + const int64_t next_physical_block_idx = + seq_block_table[block_idx + 1]; + const scalar_t *__restrict__ next_v_block_cache_ptr = + v_cache + next_physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; + vec_op::unroll_loop( + [&](int head_elem_idx) { + if (head_elem_idx % 2 == 0) { + vec_op::prefetch(next_v_block_cache_ptr + + BLOCK_SIZE * head_elem_idx); + } + }); + } + } + + vec_op::unroll_loop( + [&](int head_elem_idx) { + float value = accums[head_elem_idx].reduce_sum(); + vec_op::storeFP32(value, out_ptr + head_elem_idx); + }); + } + } + } + std::free(logits); + } +}; + +#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v1_impl::call( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ + num_heads); + +template +void paged_attention_v1_impl_launcher( + torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, float scale, + torch::Tensor &block_tables, torch::Tensor &context_lens, + int max_context_len, const c10::optional &alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: alibi_slopes is optional. + const float *alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T *out_ptr = reinterpret_cast(out.data_ptr()); + T *query_ptr = reinterpret_cast(query.data_ptr()); + T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int *block_tables_ptr = block_tables.data_ptr(); + int *context_lens_ptr = context_lens.data_ptr(); + + switch (head_size) { + case 64: + LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 256: + LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v1_impl_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + context_lens, max_context_len, alibi_slopes); + +#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V1_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } +} // namespace + +void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, + torch::Tensor &key_cache, torch::Tensor &value_cache, + int num_kv_heads, float scale, + torch::Tensor &block_tables, + torch::Tensor &context_lens, int block_size, + int max_context_len, + const c10::optional &alibi_slopes, + const std::string &kv_cache_dtype, float kv_scale) { + TORCH_CHECK(kv_scale == 1.0f); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", + [&] { + CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) + CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); + CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl) + }); +} + +// Paged attention v2 +namespace { +template +struct paged_attention_v2_impl { + static void call( + scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] + float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float + *__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int + *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int *__restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float *__restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const int num_seqs, const int num_heads, const int max_num_partitions) { + constexpr int x = 16 / sizeof(scalar_t); + const int num_queries_per_kv = num_heads / num_kv_heads; + + static_assert(BLOCK_SIZE == 16); + static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0); + static_assert(PARTITION_SIZE % BLOCK_SIZE == 0); + +#pragma omp parallel for collapse(3) schedule(static, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int partition_idx = 0; partition_idx < max_num_partitions; + ++partition_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + const int context_len = context_lens[seq_idx]; + const int start_token_idx = partition_idx * PARTITION_SIZE; + + if (start_token_idx >= context_len) + continue; + + const int partition_num = + (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + const bool no_reduce = (partition_num == 1); + const int context_token_num = + (std::min(context_len, start_token_idx + PARTITION_SIZE) - + start_token_idx); + const int block_num = + (context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int last_block_token_num = + context_token_num - (block_num - 1) * BLOCK_SIZE; + const int *seq_block_table = block_tables + + max_num_blocks_per_seq * seq_idx + + start_token_idx / BLOCK_SIZE; + const int64_t kv_head_idx = head_idx / num_queries_per_kv; + const scalar_t *__restrict__ q_vec_ptr = + q + seq_idx * q_stride + head_idx * HEAD_SIZE; + + float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; + + // Compute logits + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const scalar_t *__restrict__ k_block_cache_ptr = + k_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride; + float *__restrict__ head_block_logits = + logits + block_idx * BLOCK_SIZE; + + reduceQKBlockKernel::call( + q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, + block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE); + } + + std::pair max_and_sum; + if (alibi_slopes) { + max_and_sum = reduceSoftmaxAlibi( + logits, context_token_num, block_num * BLOCK_SIZE, + alibi_slopes[head_idx], start_token_idx, context_len); + } else { + max_and_sum = reduceSoftmax(logits, context_token_num, + block_num * BLOCK_SIZE); + } + + auto &&[max_logit, exp_sum] = max_and_sum; + + scalar_t *__restrict__ output_buffer = nullptr; + if (!no_reduce) { + auto idx = seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + max_logits[idx] = max_logit; + exp_sums[idx] = exp_sum; + output_buffer = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + } else { + output_buffer = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + } + + // Compute value + constexpr int head_elem_num_per_partition = 16; + constexpr int head_partition_num = + HEAD_SIZE / head_elem_num_per_partition; + for (int head_part_idx = 0; head_part_idx < head_partition_num; + ++head_part_idx) { + vec_op::FP32Vec16 accums[head_elem_num_per_partition]; + scalar_t *__restrict__ out_ptr = + output_buffer + head_part_idx * head_elem_num_per_partition; + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const float *__restrict__ prob_vec_ptr = + logits + block_idx * BLOCK_SIZE; + const scalar_t *__restrict__ v_block_cache_ptr = + v_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; + reduceValueBlock( + prob_vec_ptr, v_block_cache_ptr, accums); + + if (block_idx != block_num - 1) { + const int64_t next_physical_block_idx = + seq_block_table[block_idx + 1]; + const scalar_t *__restrict__ next_v_block_cache_ptr = + v_cache + next_physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; + vec_op::unroll_loop( + [&](int head_elem_idx) { + if (head_elem_idx % 2 == 0) { + vec_op::prefetch(next_v_block_cache_ptr + + BLOCK_SIZE * head_elem_idx); + } + }); + } + } + + vec_op::unroll_loop( + [&](int head_elem_idx) { + float value = accums[head_elem_idx].reduce_sum(); + vec_op::storeFP32(value, out_ptr + head_elem_idx); + }); + } + } + } + } + + // Rescale partition softmax and store the factors to exp_sums +#pragma omp parallel for collapse(2) schedule(static, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + const int context_len = context_lens[seq_idx]; + const int partition_num = + (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + + if (partition_num == 1) + continue; + + reducePartitonSoftmax( + max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions, + exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions, + partition_num); + } + } + + // Reduce values + using v_load_vec_type = typename KernelVecType::v_load_vec_type; + static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); + constexpr int head_elem_num_per_group = + 16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE + // didn't align with 64 bytes + static_assert(HEAD_SIZE % head_elem_num_per_group == 0); + constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; + const float *__restrict__ rescale_factors = exp_sums; +#pragma omp parallel for collapse(3) schedule(static, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { + const int context_len = context_lens[seq_idx]; + const int partition_num = + (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + + if (partition_num == 1) + continue; + + const float *__restrict__ seq_head_rescale_factors = + rescale_factors + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + const scalar_t *__restrict__ seq_head_tmp_out = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + group_idx * head_elem_num_per_group; + scalar_t *__restrict__ seq_head_output = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + + group_idx * head_elem_num_per_group; + + vec_op::FP32Vec16 acc; + for (int i = 0; i < partition_num; ++i) { + vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]); + v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE); + vec_op::FP32Vec16 fp32_value(value); + acc = acc + fp32_value * rescale_factor; + } + v_load_vec_type cast_acc(acc); + cast_acc.save(seq_head_output); + } + } + } + } +}; + +#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v2_impl::call( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ + key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, num_seqs, num_heads, \ + max_num_partitions); + +template +void paged_attention_v2_impl_launcher( + torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, + torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, float scale, + torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, + int max_context_len, const c10::optional &alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + int max_num_partitions = exp_sums.size(-1); + + // NOTE: alibi_slopes is optional. + const float *alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T *out_ptr = reinterpret_cast(out.data_ptr()); + float *exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float *max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T *tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T *query_ptr = reinterpret_cast(query.data_ptr()); + T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int *block_tables_ptr = block_tables.data_ptr(); + int *context_lens_ptr = context_lens.data_ptr(); + + switch (head_size) { + case 64: + LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 256: + LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v2_impl_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, block_size, \ + max_context_len, alibi_slopes); + +#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V2_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } +} // namespace + +void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, + torch::Tensor &max_logits, torch::Tensor &tmp_out, + torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, + float scale, torch::Tensor &block_tables, + torch::Tensor &context_lens, int block_size, + int max_context_len, + const c10::optional &alibi_slopes, + const std::string &kv_cache_dtype, float kv_scale) { + TORCH_CHECK(kv_scale == 1.0f); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", + [&] { + CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) + CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); + CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl) + }); +} diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp new file mode 100644 index 0000000000000..7849a5df991b1 --- /dev/null +++ b/csrc/cpu/cache.cpp @@ -0,0 +1,141 @@ +#include +#include + +#include "cpu_types.hpp" + +namespace { +template +void copy_blocks_cpu_impl( + std::vector &key_caches, + std::vector &value_caches, + const std::vector> mapping_pairs, + const int element_num_per_block, const int layer_num) { + const size_t pair_num = mapping_pairs.size(); + const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; +#pragma omp parallel for collapse(2) + for (int layer = 0; layer < layer_num; ++layer) { + for (size_t pair = 0; pair < pair_num; ++pair) { + int64_t source_offset = element_num_per_block * mapping_pairs[pair].first; + int64_t target_offset = + element_num_per_block * mapping_pairs[pair].second; + scalar_t *key_cache_ptr = key_caches[layer].data_ptr(); + scalar_t *source_ptr = key_cache_ptr + source_offset; + scalar_t *target_ptr = key_cache_ptr + target_offset; + std::memcpy(target_ptr, source_ptr, block_bytes); + + scalar_t *value_cache_ptr = value_caches[layer].data_ptr(); + source_ptr = value_cache_ptr + source_offset; + target_ptr = value_cache_ptr + target_offset; + std::memcpy(target_ptr, source_ptr, block_bytes); + } + } +} + +template +void reshape_and_cache_cpu_impl( + const scalar_t *__restrict__ key, const scalar_t *__restrict__ value, + scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache, + const int64_t *__restrict__ slot_mapping, const int num_tokens, + const int key_stride, const int value_stride, const int num_heads, + const int head_size, const int block_size, const int x) { + const int block_elem_num = num_heads * head_size * block_size; + +#pragma omp parallel for collapse(2) + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx >= 0) { + int src_key_head_idx = token_idx * key_stride + head_idx * head_size; + int src_value_head_idx = + token_idx * value_stride + head_idx * head_size; + const scalar_t *src_key_head_ptr = key + src_key_head_idx; + const scalar_t *src_value_head_ptr = value + src_value_head_idx; + const int64_t block_index = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + scalar_t *target_key_head_ptr = key_cache + + block_elem_num * block_index + + head_idx * block_size * head_size; + scalar_t *target_value_head_ptr = value_cache + + block_elem_num * block_index + + head_idx * block_size * head_size; + + for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) { + const int64_t target_offset = + src_key_idx * block_size + block_offset * x; + for (int i = 0; i < x; ++i) { + target_key_head_ptr[target_offset + i] = + src_key_head_ptr[src_key_idx + i]; + } + } + + for (int src_value_idx = 0; src_value_idx < head_size; + ++src_value_idx) { + const int64_t target_offset = + src_value_idx * block_size + block_offset; + target_value_head_ptr[target_offset] = + src_value_head_ptr[src_value_idx]; + } + } + } + } +} +}; // namespace + +void copy_blocks(std::vector &key_caches, + std::vector &value_caches, + const std::map> &block_mapping) { + int num_layers = key_caches.size(); + TORCH_CHECK(num_layers == value_caches.size()); + if (num_layers == 0) { + return; + } + + std::vector> mapping_pairs; + mapping_pairs.reserve(block_mapping.size()); + for (const auto &pair : block_mapping) { + for (const auto &dst : pair.second) { + mapping_pairs.emplace_back(pair.first, dst); + } + } + + const int element_num_per_block = key_caches[0][0].numel(); + VLLM_DISPATCH_FLOATING_TYPES( + key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { + CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) + copy_blocks_cpu_impl(key_caches, value_caches, mapping_pairs, + element_num_per_block, num_layers); + CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) + }); +} + +void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, + torch::Tensor &key_cache, torch::Tensor &value_cache, + torch::Tensor &slot_mapping, + const std::string &kv_cache_dtype, float kv_scale) { + TORCH_CHECK(kv_scale == 1.0f); + + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), "reshape_and_cache_cpu_impl", [&] { + CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl) + reshape_and_cache_cpu_impl( + key.data_ptr(), value.data_ptr(), + key_cache.data_ptr(), value_cache.data_ptr(), + slot_mapping.data_ptr(), num_tokens, key_stride, + value_stride, num_heads, head_size, block_size, x); + CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl) + }); +} + +void swap_blocks(torch::Tensor &src, torch::Tensor &dst, + const std::map &block_mapping) { + TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") +} diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp new file mode 100644 index 0000000000000..c1d3ec058b991 --- /dev/null +++ b/csrc/cpu/cpu_types.hpp @@ -0,0 +1,352 @@ + +#ifndef CPU_TYPES_HPP +#define CPU_TYPES_HPP + +#include +#include + +namespace vec_op { + +// FIXME: FP16 is not fully supported in Torch-CPU +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD +#define CPU_KERNEL_GUARD_IN(NAME) +#define CPU_KERNEL_GUARD_OUT(NAME) +#else +#define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; +#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; +#endif + +#define FORCE_INLINE __attribute__((always_inline)) inline + +namespace { +template +constexpr void unroll_loop_item(std::integer_sequence, F &&f) { + (f(std::integral_constant{}), ...); +} +}; // namespace + +template >> +constexpr void unroll_loop(F &&f) { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); +} + +template struct Vec { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } +}; + +struct FP32Vec8; +struct FP32Vec16; + +#ifdef __AVX512FP16__ +struct FP16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __m128h reg; + + explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {} + + explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {} + + explicit FP16Vec8(__m128h data) : reg(data) {} + + FP16Vec8 operator*(const FP16Vec8 &b) const { + return FP16Vec8(_mm_mul_ph(reg, b.reg)); + } + + FP16Vec8 operator+(const FP16Vec8 &b) const { + return FP16Vec8(_mm_add_ph(reg, b.reg)); + } + + FP16Vec8 operator-(const FP16Vec8 &b) const { + return FP16Vec8(_mm_sub_ph(reg, b.reg)); + } + + FP16Vec8 operator/(const FP16Vec8 &b) const { + return FP16Vec8(_mm_div_ph(reg, b.reg)); + } + + void save(void *ptr) const { _mm_storeu_ph(ptr, reg); } +}; +#endif + +struct BF16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __m128i reg; + + explicit BF16Vec8(const void *ptr) + : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} + + explicit BF16Vec8(const FP32Vec8 &); + + void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } +}; + +struct BF16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + __m256i reg; + + explicit BF16Vec16(const void *ptr) + : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} + + explicit BF16Vec16(const FP32Vec16 &); + + void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } +}; + +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + __m512i reg; + + explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {} + + explicit BF16Vec32(__m512i data) : reg(data) {} + + explicit BF16Vec32(BF16Vec8 &vec8_data) + : reg((__m512i)_mm512_inserti32x4( + _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512( + (__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1), + (__m128i)vec8_data.reg, 2), + (__m128i)vec8_data.reg, 3)) {} + + void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; } +}; + +struct FP32Vec4 : public Vec { + constexpr static int VEC_ELEM_NUM = 4; + union AliasReg { + __m128 reg; + float values[VEC_ELEM_NUM]; + }; + + __m128 reg; + + explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {} + + explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {} + + explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {} + + explicit FP32Vec4(__m128 data) : reg(data) {} + + explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} +}; + +struct FP32Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + union AliasReg { + __m256 reg; + float values[VEC_ELEM_NUM]; + }; + + __m256 reg; + + explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {} + + explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {} + + explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {} + + explicit FP32Vec8(__m256 data) : reg(data) {} + + explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {} + +#ifdef __AVX512FP16__ + explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {} +#endif + + explicit FP32Vec8(const BF16Vec8 &v) + : reg(_mm256_castsi256_ps( + _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {} + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + + return result; + } + + FP32Vec8 exp() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]), + expf(ar.values[5]), expf(ar.values[4]), + expf(ar.values[3]), expf(ar.values[2]), + expf(ar.values[1]), expf(ar.values[0]))); + } + + FP32Vec8 tanh() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]), + tanhf(ar.values[5]), tanhf(ar.values[4]), + tanhf(ar.values[3]), tanhf(ar.values[2]), + tanhf(ar.values[1]), tanhf(ar.values[0]))); + } + + FP32Vec8 er() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]), + erf(ar.values[5]), erf(ar.values[4]), + erf(ar.values[3]), erf(ar.values[2]), + erf(ar.values[1]), erf(ar.values[0]))); + } + + FP32Vec8 operator*(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_mul_ps(reg, b.reg)); + } + + FP32Vec8 operator+(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_add_ps(reg, b.reg)); + } + + FP32Vec8 operator-(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_sub_ps(reg, b.reg)); + } + + FP32Vec8 operator/(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_div_ps(reg, b.reg)); + } + + void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } +}; + +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + __m512 reg; + float values[VEC_ELEM_NUM]; + }; + + __m512 reg; + + explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {} + + explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} + + explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {} + + explicit FP32Vec16(__m512 data) : reg(data) {} + + explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} + + explicit FP32Vec16(const FP32Vec4 &data) + : reg((__m512)_mm512_inserti32x4( + _mm512_inserti32x4( + _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg), + (__m128i)data.reg, 1), + (__m128i)data.reg, 2), + (__m128i)data.reg, 3)) {} + + explicit FP32Vec16(const FP32Vec8 &data) + : reg((__m512)_mm512_inserti32x8( + _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {} + + explicit FP32Vec16(const BF16Vec16 &v) + : reg(_mm512_castsi512_ps( + _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {} + + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_mul_ps(reg, b.reg)); + } + + FP32Vec16 operator+(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_add_ps(reg, b.reg)); + } + + FP32Vec16 operator-(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_sub_ps(reg, b.reg)); + } + + FP32Vec16 operator/(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_div_ps(reg, b.reg)); + } + + float reduce_sum() const { return _mm512_reduce_add_ps(reg); } + + template float reduce_sub_sum(int idx) { + static_assert(VEC_ELEM_NUM % group_size == 0); + constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); + __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size)); + return _mm512_mask_reduce_add_ps(mask, reg); + } + + void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } +}; + +template struct VecType { using vec_type = void; }; + +template using vec_t = typename VecType::vec_type; + +template <> struct VecType { using vec_type = FP32Vec8; }; + +#ifdef __AVX512FP16__ +template <> struct VecType { using vec_type = FP16Vec16; }; +#endif + +template <> struct VecType { using vec_type = BF16Vec8; }; + +template void storeFP32(float v, T *ptr) { *ptr = v; } + +#ifdef __AVX512FP16__ +template <> inline void storeFP32(float v, c10::Half *ptr) { + *reinterpret_cast<_Float16 *>(ptr) = v; +} +#endif + +inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { + acc = acc + a * b; +} + +#ifdef __AVX512BF16__ +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v); +} + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) + : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {} + +inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { + acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg); +} +#else +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = + reinterpret_cast(&v); + *ptr = *(v_ptr + 1); +} + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg(_mm256_cvtepi32_epi16( + _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) + : reg(_mm512_cvtepi32_epi16( + _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {} +#endif + +inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); } + +}; // namespace vec_op + +#endif diff --git a/csrc/cpu/layernorm.cpp b/csrc/cpu/layernorm.cpp new file mode 100644 index 0000000000000..467f0dc84982c --- /dev/null +++ b/csrc/cpu/layernorm.cpp @@ -0,0 +1,117 @@ +#include "cpu_types.hpp" + +namespace { +template +void rms_norm_impl(scalar_t *__restrict__ out, + const scalar_t *__restrict__ input, + const scalar_t *__restrict__ weight, const float epsilon, + const int num_tokens, const int hidden_size) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + vec_op::FP32Vec8 variance(0.0); + auto input_p = input + i * hidden_size; + auto output_p = out + i * hidden_size; + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t x(input_p + j); + vec_op::FP32Vec8 fp32_x(x); + variance = variance + fp32_x * fp32_x; + } + + float s_variance = + 1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon); + vec_op::FP32Vec8 fp32_s_variance(s_variance); + + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t x(input_p + j); + scalar_vec_t w(weight + j); + + vec_op::FP32Vec8 fp32_x(x); + vec_op::FP32Vec8 fp32_w(w); + + vec_op::FP32Vec8 fp32_out = fp32_x * fp32_s_variance * fp32_w; + + scalar_vec_t out(fp32_out); + out.save(output_p + j); + } + } +} + +template +void fused_add_rms_norm_impl(scalar_t *__restrict__ input, + scalar_t *__restrict__ residual, + const scalar_t *__restrict__ weight, + const float epsilon, const int num_tokens, + const int hidden_size) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + vec_op::FP32Vec8 variance(0.0); + auto input_p = input + i * hidden_size; + auto residual_p = residual + i * hidden_size; + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t x(input_p + j); + scalar_vec_t res(residual_p + j); + vec_op::FP32Vec8 fp32_x(x); + vec_op::FP32Vec8 fp32_res(res); + + fp32_x = fp32_x + fp32_res; + variance = variance + fp32_x * fp32_x; + scalar_vec_t out(fp32_x); + out.save(residual_p + j); + } + + float s_variance = + 1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon); + vec_op::FP32Vec8 fp32_s_variance(s_variance); + + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t w(weight + j); + scalar_vec_t res(residual_p + j); + + vec_op::FP32Vec8 fp32_w(w); + vec_op::FP32Vec8 fp32_res(res); + + vec_op::FP32Vec8 fp32_out = fp32_res * fp32_s_variance * fp32_w; + + scalar_vec_t out(fp32_out); + out.save(input_p + j); + } + } +} +} // namespace + +void rms_norm(torch::Tensor &out, torch::Tensor &input, + torch::Tensor &weight, float epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] { + CPU_KERNEL_GUARD_IN(rms_norm_impl) + rms_norm_impl(out.data_ptr(), input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, + hidden_size); + CPU_KERNEL_GUARD_OUT(rms_norm_impl) + }); +} + +void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, + torch::Tensor &weight, float epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "fused_add_rms_norm_impl", [&] { + CPU_KERNEL_GUARD_IN(fused_add_rms_norm_impl) + fused_add_rms_norm_impl( + input.data_ptr(), residual.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, hidden_size); + CPU_KERNEL_GUARD_OUT(fused_add_rms_norm_impl) + }); +} diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp new file mode 100644 index 0000000000000..e9b3992204bb2 --- /dev/null +++ b/csrc/cpu/pos_encoding.cpp @@ -0,0 +1,199 @@ + +#include "cpu_types.hpp" + +namespace { +template +void rotary_embedding_impl( + const int64_t + *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t + *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or + /// [num_tokens, num_heads, head_size] + scalar_t + *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or + // [num_tokens, num_kv_heads, head_size] + const scalar_t + *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size, + const int num_tokens) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + constexpr int ELEM_SIZE = sizeof(scalar_t); + + const int embed_dim = rot_dim / 2; + TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + int64_t pos = positions[token_idx]; + const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; + + for (int i = 0; i < num_heads; ++i) { + const int head_idx = i; + const int64_t token_head = + token_idx * query_stride + head_idx * head_size; + for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { + const int rot_offset = j; + const int x_index = rot_offset; + const int y_index = embed_dim + rot_offset; + + const int64_t out_x = token_head + x_index; + const int64_t out_y = token_head + y_index; + + const scalar_vec_t cos(cache_ptr + x_index); + const scalar_vec_t sin(cache_ptr + y_index); + + const scalar_vec_t q_x(query + out_x); + const scalar_vec_t q_y(query + out_y); + + vec_op::FP32Vec8 fp32_cos(cos); + vec_op::FP32Vec8 fp32_sin(sin); + + vec_op::FP32Vec8 fp32_q_x(q_x); + vec_op::FP32Vec8 fp32_q_y(q_y); + + auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; + scalar_vec_t(out1).save(query + out_x); + + auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; + scalar_vec_t(out2).save(query + out_y); + } + } + + for (int i = 0; i < num_kv_heads; ++i) { + const int head_idx = i; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { + const int rot_offset = j; + const int x_index = rot_offset; + const int y_index = embed_dim + rot_offset; + + const int64_t out_x = token_head + x_index; + const int64_t out_y = token_head + y_index; + + const scalar_vec_t cos(cache_ptr + x_index); + const scalar_vec_t sin(cache_ptr + y_index); + + const scalar_vec_t k_x(key + out_x); + const scalar_vec_t k_y(key + out_y); + + vec_op::FP32Vec8 fp32_cos(cos); + vec_op::FP32Vec8 fp32_sin(sin); + + vec_op::FP32Vec8 fp32_k_x(k_x); + vec_op::FP32Vec8 fp32_k_y(k_y); + + auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin; + scalar_vec_t(out1).save(key + out_x); + auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin; + scalar_vec_t(out2).save(key + out_y); + } + } + } +} + +template +void rotary_embedding_gptj_impl( + const int64_t + *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t + *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or + /// [num_tokens, num_heads, head_size] + scalar_t + *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or + // [num_tokens, num_kv_heads, head_size] + const scalar_t + *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size, + const int num_tokens) { + const int embed_dim = rot_dim / 2; + +#pragma omp parallel for collapse(2) + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (int i = 0; i < num_heads; ++i) { + int64_t pos = positions[token_idx]; + const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t *cos_cache_ptr = cache_ptr; + const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; + const int head_idx = i; + const int64_t token_head = + token_idx * query_stride + head_idx * head_size; + scalar_t *head_query = token_head + query; + for (int j = 0; j < embed_dim; j += 1) { + const int rot_offset = j; + const int x_index = 2 * rot_offset; + const int y_index = 2 * rot_offset + 1; + + const float cos = cos_cache_ptr[rot_offset]; + const float sin = sin_cache_ptr[rot_offset]; + + const float x = head_query[x_index]; + const float y = head_query[y_index]; + + head_query[x_index] = x * cos - y * sin; + head_query[y_index] = y * cos + x * sin; + } + } + } + +#pragma omp parallel for collapse(2) + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (int i = 0; i < num_kv_heads; ++i) { + int64_t pos = positions[token_idx]; + const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t *cos_cache_ptr = cache_ptr; + const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; + const int head_idx = i; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + scalar_t *head_key = key + token_head; + for (int j = 0; j < embed_dim; j += 1) { + const int rot_offset = j; + const int x_index = 2 * rot_offset; + const int y_index = 2 * rot_offset + 1; + + const float cos = cos_cache_ptr[rot_offset]; + const float sin = sin_cache_ptr[rot_offset]; + + const float x = head_key[x_index]; + const float y = head_key[y_index]; + + head_key[x_index] = x * cos - y * sin; + head_key[y_index] = y * cos + x * sin; + } + } + } +} +}; // namespace + +void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, + torch::Tensor &key, int head_size, + torch::Tensor &cos_sin_cache, bool is_neox) { + int num_tokens = query.numel() / query.size(-1); + int rot_dim = cos_sin_cache.size(1); + int num_heads = query.size(-1) / head_size; + int num_kv_heads = key.size(-1) / head_size; + int64_t key_stride = key.stride(-2); + int64_t query_stride = query.stride(-2); + + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), "rotary_embedding_impl", [&] { + CPU_KERNEL_GUARD_IN(rotary_embedding_impl) + if (is_neox) { + rotary_embedding_impl( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + rot_dim, query_stride, key_stride, num_heads, num_kv_heads, + head_size, num_tokens); + } else { + rotary_embedding_gptj_impl( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + rot_dim, query_stride, key_stride, num_heads, num_kv_heads, + head_size, num_tokens); + } + + CPU_KERNEL_GUARD_OUT(rotary_embedding_impl) + }); +} diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp new file mode 100644 index 0000000000000..bba044087f37c --- /dev/null +++ b/csrc/cpu/pybind.cpp @@ -0,0 +1,73 @@ +#include "cache.h" +#include "cuda_utils.h" +#include "ops.h" +#include + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // vLLM custom ops + pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); + + // Attention ops + ops.def( + "paged_attention_v1", + &paged_attention_v1, + "Compute the attention between an input query and the cached keys/values using PagedAttention."); + ops.def( + "paged_attention_v2", + &paged_attention_v2, + "PagedAttention V2."); + + // Activation ops + ops.def( + "silu_and_mul", + &silu_and_mul, + "Activation function used in SwiGLU."); + ops.def( + "gelu_and_mul", + &gelu_and_mul, + "Activation function used in GeGLU with `none` approximation."); + ops.def( + "gelu_tanh_and_mul", + &gelu_tanh_and_mul, + "Activation function used in GeGLU with `tanh` approximation."); + ops.def( + "gelu_new", + &gelu_new, + "GELU implementation used in GPT-2."); + ops.def( + "gelu_fast", + &gelu_fast, + "Approximate GELU implementation."); + + // Layernorm + ops.def( + "rms_norm", + &rms_norm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); + + ops.def( + "fused_add_rms_norm", + &fused_add_rms_norm, + "In-place fused Add and RMS Normalization"); + + // Rotary embedding + ops.def( + "rotary_embedding", + &rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + + // Cache ops + pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); + cache_ops.def( + "swap_blocks", + &swap_blocks, + "Swap in (out) the cache blocks from src to dst"); + cache_ops.def( + "copy_blocks", + ©_blocks, + "Copy the cache blocks from src to dst"); + cache_ops.def( + "reshape_and_cache", + &reshape_and_cache, + "Reshape the key and value tensors and cache them"); +} diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index ea30fa2747838..e56b4d2204005 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -59,6 +59,8 @@ __global__ void rms_norm_kernel( template struct _typeConvert { static constexpr bool exists = false; }; +#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) +// CUDA < 12.0 runs into issues with packed type conversion template<> struct _typeConvert { static constexpr bool exists = true; @@ -85,8 +87,8 @@ struct _typeConvert { __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); } __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } }; -#endif - +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) /* Vector POD struct to generate vectorized and packed FP16/BF16 ops for appropriate specializations of fused_add_rms_norm_kernel. diff --git a/csrc/ops.h b/csrc/ops.h index d5d6e240da7c4..41ecc1e89371b 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -14,7 +14,8 @@ void paged_attention_v1( int block_size, int max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype); + const std::string& kv_cache_dtype, + float kv_scale); void paged_attention_v2( torch::Tensor& out, @@ -31,7 +32,8 @@ void paged_attention_v2( int block_size, int max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype); + const std::string& kv_cache_dtype, + float kv_scale); void rms_norm( torch::Tensor& out, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index a5c6439fd6909..de02afc162113 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -91,9 +91,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &reshape_and_cache, "Reshape the key and value tensors and cache them"); cache_ops.def( - "convert_fp8_e5m2", - &convert_fp8_e5m2, - "Convert the key and value cache to fp8_e5m2 data type"); + "convert_fp8", + &convert_fp8, + "Convert the key and value cache to fp8 data type"); // Cuda utils pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); diff --git a/csrc/quantization/fp8/amd_detail/hip_float8.h b/csrc/quantization/fp8/amd_detail/hip_float8.h new file mode 100644 index 0000000000000..87c7c9ce66100 --- /dev/null +++ b/csrc/quantization/fp8/amd_detail/hip_float8.h @@ -0,0 +1,167 @@ +#pragma once + +#ifdef __HIPCC__ +#include +#else +#include +#include +#include +#include +#endif + +#include "hip_float8_impl.h" + +struct alignas(1) hip_fp8 +{ + struct from_bits_t + { + }; + HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); } + uint8_t data; + + hip_fp8() = default; + HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; + HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; + explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) + : data(v) + { + } + +#ifdef __HIP__MI300__ + // NOTE: ON-DEVICE... always optimal bias + explicit HIP_FP8_DEVICE hip_fp8(float v) + : data(hip_fp8_impl::to_fp8_from_fp32(v)) + { + } + + explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) + : hip_fp8(static_cast(v)) + { + } + + // Host only implementation using s/w simulation + explicit HIP_FP8_HOST +#else // __HIP__MI300__ + // both Host and DEVICE for non-MI300 using s/w simulation + explicit HIP_FP8_HOST_DEVICE +#endif // __HIP__MI300__ + hip_fp8(float v) + { + data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v); + } + + explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) + : hip_fp8(static_cast(v)) + { + } + +#ifdef __HIP__MI300__ + // upcast using device specific intrinsic + explicit inline HIP_FP8_DEVICE operator float() const + { + float fval; + uint32_t i32val = static_cast(data); + + // upcast + asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); + + return fval; + } + + explicit inline HIP_FP8_HOST operator float() const +#else // __HIP__MI300__ + explicit inline HIP_FP8_HOST_DEVICE operator float() const +#endif // __HIP__MI300__ + { + return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data); + } +}; + +namespace std +{ +inline hip_fp8 sin(hip_fp8 a) +{ + return hip_fp8(sinf(float(a))); +} +inline hip_fp8 cos(hip_fp8 a) +{ + return hip_fp8(cosf(float(a))); +} +HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) +{ + return a; +} +} // namespace std + +// Special operator overloading +inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) +{ + return os << float(f8); +} + +// all + operator overloading with mixed types +// mixed types, always converts to f32, does computation in f32, and returns float +inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) +{ + return (fa + float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) +{ + return (float(a) + fb); +} + +inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) +{ + return hip_fp8(float(a) + float(b)); +} + +inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) +{ + return a = hip_fp8(float(a) + float(b)); +} + +// overloading multiplication, always returns float, +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) +{ + return float(a) * float(b); +} + +inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) +{ + return (a * float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) +{ + return (float(a) * b); +} + +inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) +{ + return ((float)a * float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) +{ + return ((float)a * float(b)); +} + +// overloading for compare +inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) +{ + return (a.data == b.data); +} +inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) +{ + return (a.data != b.data); +} + +inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) +{ + return static_cast(a) >= static_cast(b); +} +inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) +{ + return static_cast(a) > static_cast(b); +} diff --git a/csrc/quantization/fp8/amd_detail/hip_float8_impl.h b/csrc/quantization/fp8/amd_detail/hip_float8_impl.h new file mode 100644 index 0000000000000..e05905b4e49e8 --- /dev/null +++ b/csrc/quantization/fp8/amd_detail/hip_float8_impl.h @@ -0,0 +1,316 @@ +#pragma once + +#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#define __HIP__MI300__ +#endif + +#ifdef __HIPCC__ +#define HIP_FP8_HOST_DEVICE __host__ __device__ +#define HIP_FP8_HOST __host__ +#define HIP_FP8_DEVICE __device__ +#else +#define HIP_FP8_HOST_DEVICE +#define HIP_FP8_HOST +#define HIP_FP8_DEVICE +#endif + +namespace hip_fp8_impl +{ + +#ifdef __HIP__MI300__ +HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) +{ + uint8_t i8data; + union { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // NOTE: not endian independent + } val; + + uint32_t ival = 0; + val.fval = v; + + if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); + } + + ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, + false); // false -> WORD0 + val.i32val = ival; + i8data = val.i8val[0]; + + return i8data; +} +#endif // __HIP__MI300__ + +HIP_FP8_HOST inline int clz(uint32_t x) +{ + return __builtin_clz(x); +} +#if defined(__HIPCC__) || defined(__CUDA_ARCH__) +HIP_FP8_DEVICE inline int clz(uint32_t x) +{ + return __clz(x); +} +#endif + +template +HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0) +{ +#ifdef __HIPCC__ + constexpr bool is_half = std::is_same::value; +#else + constexpr bool is_half = false; +#endif + constexpr bool is_float = std::is_same::value; + static_assert(wm + we == 7, "wm+we==7"); + static_assert(is_half || is_float, "Only half and float can be cast to f8"); + + const int mfmt = (sizeof(T) == 4) ? 23 : 10; + uint32_t x; + if (sizeof(T) == 4) { + x = reinterpret_cast(_x); + } else { + x = reinterpret_cast(_x); + } + + uint32_t head, mantissa; + int exponent, bias; + uint32_t sign; + + if (sizeof(T) == 4) { + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; + bias = 127; + } else { + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; + bias = 15; + } + + uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); + + // Deal with inf and NaNs + if (negative_zero_nan) { + if (sizeof(T) == 4) { + if ((x & 0x7F800000) == 0x7F800000) { + return 0x80; + } + } else { + // if(__hisinf(x) || __hisnan(x)) + if ((x & 0x7C00) == 0x7C00) { + return 0x80; + } + } + } else { + if (sizeof(T) == 4) { + if ((x & 0x7F800000) == 0x7F800000) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } + } else { + if ((x & 0x7C00) == 0x7C00) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } + } + } + if (x == 0) { + return 0; + } + + // First need to check if it is normal or denorm as there is a difference of + // implicit 1 Then need to adjust the exponent to align with the F8 exponent, + // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng + // to mantissa and truncate. And for RNE, no need to add rng. Then probably + // need to check whether there is carry and adjust exponent and mantissa again + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent + // bits + const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // f8_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, f8_exponent, exponent_diff; + + if (exponent == 0) { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we +mostly concern fp16 here. In this case, f8 is usually in denormal. But there +could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has +exponent bias 16. It means that there are some numbers in fp16 denormal but they +are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers +where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 +(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } else { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if (act_exponent <= f8_denormal_act_exponent) { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal + range. For example fp8 nanoo mode, denormal exponent is -7, but if the + fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } else { // both fp32/fp16 and f8 are in normal range + exponent_diff = 0; // exponent_diff=0 does not mean there is no difference + // for this case, + // act_exponent could be larger. Just that it does not need shift mantissa + } + mantissa += (1 << mfmt); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == + static_cast(1 << (mfmt - wm + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be + done before we shift right as shift right could rip off some residual part + and make something not midpoint look like midpoint. For example, the fp16 + number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after + shift right by 4 bits, it would look like midpoint. +*/ + + if (exponent_diff > 0) { + mantissa >>= exponent_diff; + } else if (exponent_diff == -1) { + mantissa <<= -exponent_diff; + } + bool implicit_one = mantissa & (1 << mfmt); + // if there is no implicit 1, it means the f8 is denormal and need to adjust + // to denorm exponent + f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + uint32_t drop_mask = (1 << (mfmt - wm)) - 1; + bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit that + // is not truncated is 1 + mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; + + // Now we deal with overflow + if (f8_exponent == 0) { + if ((1 << mfmt) & mantissa) { + f8_exponent = 1; // denormal overflow to become normal, promote exponent + } + } else { + if ((1 << (mfmt + 1)) & mantissa) { + mantissa >>= 1; + f8_exponent++; + } + } + + mantissa >>= (mfmt - wm); + + // above range: quantize to maximum possible float of the same sign + const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); + if (f8_exponent > max_exp) { + if (clip) { + mantissa = (1 << wm) - 1; + f8_exponent = max_exp; + } else { + return signed_inf; + } + } + + if (f8_exponent == 0 && mantissa == 0) { + return negative_zero_nan ? 0 : (sign << 7); + } + mantissa &= (1 << wm) - 1; + return (sign << 7) | (f8_exponent << wm) | mantissa; +} + +template +inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) +{ +#ifdef __HIPCC__ + constexpr bool is_half = std::is_same::value; +#else + constexpr bool is_half = false; +#endif + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "only half and float are supported"); + + constexpr int weo = is_half ? 5 : 8; + constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); + + T fInf, fNegInf, fNaN, fNeg0; + +#ifdef __HIPCC__ + if (is_half) { + const uint16_t ihInf = 0x7C00; + const uint16_t ihNegInf = 0xFC00; + const uint16_t ihNaN = 0x7C01; + const uint16_t ihNeg0 = 0x8000; + fInf = reinterpret_cast(ihInf); + fNegInf = reinterpret_cast(ihNegInf); + fNaN = reinterpret_cast(ihNaN); + fNeg0 = reinterpret_cast(ihNeg0); + } else +#endif + if (is_float) { + const uint32_t ifInf = 0x7F800000; + const uint32_t ifNegInf = 0xFF800000; + const uint32_t ifNaN = 0x7F800001; + const uint32_t ifNeg0 = 0x80000000; + fInf = reinterpret_cast(ifInf); + fNegInf = reinterpret_cast(ifNegInf); + fNaN = reinterpret_cast(ifNaN); + fNeg0 = reinterpret_cast(ifNeg0); + } + + if (x == 0) { + return 0; + } + + uint32_t sign = x >> 7; + uint32_t mantissa = x & ((1 << wm) - 1); + int exponent = (x & 0x7F) >> wm; + if (negative_zero_nan) { + if (x == 0x80) { + return fNaN; + } + } else { + if (x == 0x80) { + return fNeg0; + } + if (exponent == ((1 << we) - 1)) { + return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; + } + } + typename std::conditional::type retval; + if (we == 5 && is_half && !negative_zero_nan) { + retval = x << 8; + return reinterpret_cast(retval); + } + + const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); + + // subnormal input + if (exponent == 0) { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + clz(mantissa) - (32 - wm); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << wm) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if (exponent <= 0) { + mantissa |= 1 << wmo; + mantissa >>= 1 - exponent; + exponent = 0; + } + + if (sizeof(T) == 2) { + retval = (sign << 15) | (exponent << 10) | mantissa; + } else { + retval = (sign << 31) | (exponent << 23) | mantissa; + } + return reinterpret_cast(retval); +} + +} // namespace hip_fp8_impl diff --git a/csrc/quantization/fp8/amd_detail/quant_utils.cuh b/csrc/quantization/fp8/amd_detail/quant_utils.cuh new file mode 100644 index 0000000000000..894160972d9f4 --- /dev/null +++ b/csrc/quantization/fp8/amd_detail/quant_utils.cuh @@ -0,0 +1,517 @@ +#pragma once +#include "hip_float8.h" + +#include +#include +#include + +#include "../../../attention/dtype_float32.cuh" +#include "../../../attention/dtype_bfloat16.cuh" + +namespace vllm +{ +namespace fp8_e4m3 { +template +__inline__ __device__ Tout vec_conversion(const Tin& x) +{ + return x; +} + +template +__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale) +{ + return x; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t vec_conversion(const uint8_t& a) +{ + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8); + return res.x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t vec_conversion(const uint16_t& a) +{ +#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0]; + tmp.h2r.y.data = f2[1]; + return tmp.ui32; +#else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = vec_conversion(static_cast(a)); + tmp.u16[1] = vec_conversion(static_cast(a >> 8U)); + return tmp.u32; +#endif +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 vec_conversion(const uint32_t& a) +{ + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a); + tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 vec_conversion(const uint2& a) +{ + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = vec_conversion(a.x); + tmp.u64[1] = vec_conversion(a.y); + return tmp.u64x2; +} + +using __nv_bfloat16 = __hip_bfloat16; + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) +{ + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f); +} + +using __nv_bfloat162 = __hip_bfloat162; + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) +{ + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t vec_conversion(const uint32_t& a) +{ + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); + res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) +{ + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float vec_conversion(const uint8_t& a) +{ + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8); +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 vec_conversion(const uint16_t& a) +{ +#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0]; + res.y = f2[1]; + return res; +#else + float2 res; + res.x = vec_conversion(static_cast(a)); + res.y = vec_conversion(static_cast(a >> 8U)); + return res; +#endif +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ vec_conversion(const uint32_t& a) +{ + Float4_ res; + res.x = vec_conversion((uint16_t)a); + res.y = vec_conversion((uint16_t)(a >> 16U)); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ vec_conversion(const uint2& a) +{ + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion(const uint16_t& a) +{ + __half_raw tmp; + tmp.x = a; + + hip_fp8 f8{static_cast(tmp.data)}; + return f8.data; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion(const __nv_bfloat16& a) +{ + hip_fp8 res{__bfloat162float(a)}; + return res.data; +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion(const float& a) +{ + hip_fp8 f8(a); + return f8.data; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 vec_conversion(const uint32_t& a) +{ + Float4_ tmp = vec_conversion(a); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + +// float2 -> half2 +template <> +__inline__ __device__ uint32_t vec_conversion(const float2& a) +{ + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +// Float4 -> half2x2 +template <> +__inline__ __device__ uint2 vec_conversion(const Float4_& a) +{ + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val); + + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val); + return b; +} + +// Float4 -> float4 +template <> +__inline__ __device__ float4 vec_conversion(const Float4_& a) +{ + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; +} + +// Float8 -> half2x4 +template <> +__inline__ __device__ uint4 vec_conversion(const Float8_& a) +{ + uint4 b; + b.x = vec_conversion(a.x); + b.y = vec_conversion(a.y); + b.z = vec_conversion(a.z); + b.w = vec_conversion(a.w); + return b; +} + +// float2 -> bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a) +{ + __nv_bfloat162 b = __float22bfloat162_rn(a); + return b; +} + +// Float4 -> bfloat162x2 +template <> +__inline__ __device__ bf16_4_t vec_conversion(const Float4_& a) +{ + bf16_4_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + return b; +} + +// Float8 -> bfloat162x4 +template <> +__inline__ __device__ bf16_8_t vec_conversion(const Float8_& a) +{ + bf16_8_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + b.z = __float22bfloat162_rn(a.z); + b.w = __float22bfloat162_rn(a.w); + return b; +} + + +/* Scaled and vectorized conversions, for data exchange between high and low precision domains + + Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale ) + s.t. + Quantize(HP / scale) => FP8 + Dequant(FP8) * scale => HP + + */ + +// fp8 -> half +template <> +__inline__ __device__ uint16_t scaled_vec_conversion(const uint8_t& a, const float scale) +{ + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8) * scale; + return res.x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t scaled_vec_conversion(const uint16_t& a, const float scale) +{ +#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0] * scale; + tmp.h2r.y.data = f2[1] * scale; + return tmp.ui32; +#else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = scaled_vec_conversion(static_cast(a), scale); + tmp.u16[1] = scaled_vec_conversion(static_cast(a >> 8U), scale); + return tmp.u32; +#endif +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 scaled_vec_conversion(const uint32_t& a, const float scale) +{ + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); + tmp.u32[1] = scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, const float scale) +{ + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale); + tmp.u64[1] = scaled_vec_conversion(a.y, scale); + return tmp.u64x2; +} + +using __nv_bfloat16 = __hip_bfloat16; + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, const float scale) +{ + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f * scale); +} + +using __nv_bfloat162 = __hip_bfloat162; + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, const float scale) +{ + __nv_bfloat162 res; + res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); + res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t scaled_vec_conversion(const uint32_t& a, const float scale) +{ + bf16_4_t res; + res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); + res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t scaled_vec_conversion(const uint2& a, const float scale) +{ + bf16_4_t tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float scaled_vec_conversion(const uint8_t& a, const float scale) +{ + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8) * scale; +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 scaled_vec_conversion(const uint16_t& a, const float scale) +{ +#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0] * scale; + res.y = f2[1] * scale; + return res; +#else + float2 res; + res.x = scaled_vec_conversion(static_cast(a), scale); + res.y = scaled_vec_conversion(static_cast(a >> 8U), scale); + return res; +#endif +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ scaled_vec_conversion(const uint32_t& a, const float scale) +{ + Float4_ res; + res.x = scaled_vec_conversion((uint16_t)a, scale); + res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ scaled_vec_conversion(const uint2& a, const float scale) +{ + Float4_ tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + + +/* Quantize(HP / scale) => FP8 */ + +// TODO(Hai): vectorized to add + +// half -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion(const uint16_t& a, const float scale) +{ + __half_raw tmp; + tmp.x = a; + + hip_fp8 f8{static_cast(tmp.data)/scale}; + return f8.data; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion(const __nv_bfloat16& a, const float scale) +{ + hip_fp8 res{__bfloat162float(a)/scale}; + return res.data; +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion(const float& a, const float scale) +{ + hip_fp8 f8(a/scale); + return f8.data; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 scaled_vec_conversion(const uint32_t& a, const float scale) +{ + Float4_ tmp = scaled_vec_conversion(a, scale); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + +} +} // namespace vllm diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 96749b9327d7a..0e76763a87b7c 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -7,4 +7,6 @@ sphinx-argparse # packages to install to build the documentation pydantic -f https://download.pytorch.org/whl/cpu -torch \ No newline at end of file +torch +py-cpuinfo +transformers diff --git a/docs/source/conf.py b/docs/source/conf.py index 61d8e55d2cc6c..44cda7c99cdd5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -11,13 +11,10 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. import logging -import os import sys from sphinx.ext import autodoc -sys.path.insert(0, os.path.abspath(os.path.join('..', '..'))) - logger = logging.getLogger(__name__) # -- Project information ----------------------------------------------------- @@ -75,6 +72,7 @@ html_theme_options = { # Mock out external dependencies here. autodoc_mock_imports = [ + "cpuinfo", "torch", "transformers", "psutil", diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst new file mode 100644 index 0000000000000..ba8b0645adcdf --- /dev/null +++ b/docs/source/getting_started/cpu-installation.rst @@ -0,0 +1,87 @@ +.. _installation_cpu: + +Installation with CPU +======================== + +vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32 and BF16. + +Table of contents: + +#. :ref:`Requirements ` +#. :ref:`Quick start using Dockerfile ` +#. :ref:`Build from source ` +#. :ref:`Performance tips ` + +.. _cpu_backend_requirements: + +Requirements +------------ + +* OS: Linux +* Compiler: gcc/g++>=12.3.0 (recommended) +* Instruction set architecture (ISA) requirement: AVX512 is required. + +.. _cpu_backend_quick_start_dockerfile: + +Quick start using Dockerfile +---------------------------- + +.. code-block:: console + + $ docker build -f Dockerfile.cpu -t vllm-cpu-env --shm-size=4g . + $ docker run -it \ + --rm \ + --network=host \ + --cpuset-cpus= \ + --cpuset-mems= \ + vllm-cpu-env + +.. _build_cpu_backend_from_source: + +Build from source +----------------- + +- First, install required compiler. We recommend to use ``gcc/g++ >= 12.3.0`` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run: + +.. code-block:: console + + $ sudo apt-get update -y + $ sudo apt-get install -y gcc-12 g++-12 + $ sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 + +- Second, install Python packages for vLLM CPU backend building: + +.. code-block:: console + + $ pip install --upgrade pip + $ pip install wheel packaging ninja setuptools>=49.4.0 numpy + $ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu + +- Finally, build and install vLLM CPU backend: + +.. code-block:: console + + $ VLLM_TARGET_DEVICE=cpu python setup.py install + +.. note:: + - BF16 is the default data type in the current CPU backend (that means the backend will cast FP16 to BF16), and is compatible will all CPUs with AVX512 ISA support. + + - AVX512_BF16 is an extension ISA provides native BF16 data type conversion and vector product instructions, will brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16. + + - If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable VLLM_CPU_AVX512BF16=1 before the building. + +.. _cpu_backend_performance_tips: + +Performance tips +----------------- + +- vLLM CPU backend uses environment variable ``VLLM_CPU_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. + +- vLLM CPU backend uses OpenMP for thread-parallel computation. If you want the best performance on CPU, it will be very critical to isolate CPU cores for OpenMP threads with other thread pools (like web-service event-loop), to avoid CPU oversubscription. + +- If using vLLM CPU backend on a bare-metal machine, it is recommended to disable the hyper-threading. + +- If using vLLM CPU backend on a multi-socket machine with NUMA, be aware to set CPU cores and memory nodes, to avoid the remote memory node access. ``numactl`` is an useful tool for CPU core and memory binding on NUMA platform. Besides, ``--cpuset-cpus`` and ``--cpuset-mems`` arguments of ``docker run`` are also useful. + + + diff --git a/docs/source/index.rst b/docs/source/index.rst index 5196ef062dc19..5d5d52696ba34 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -63,6 +63,7 @@ Documentation getting_started/installation getting_started/amd-installation getting_started/neuron-installation + getting_started/cpu-installation getting_started/quickstart .. toctree:: @@ -90,7 +91,8 @@ Documentation :caption: Quantization quantization/auto_awq - quantization/fp8_e5m2_kv_cache + quantization/fp8_e5m2_kvcache + quantization/fp8_e4m3_kvcache .. toctree:: :maxdepth: 2 diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index 45ef0340aae25..a82c2cef10e83 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -21,6 +21,8 @@ This document provides a high-level guide on integrating a `HuggingFace Transfor Start by forking our `GitHub`_ repository and then :ref:`build it from source `. This gives you the ability to modify the codebase and test your model. +.. tip:: + If you don't want to fork the repository and modify vLLM's codebase, please refer to the "Out-of-Tree Model Integration" section below. 1. Bring your model code ------------------------ @@ -94,3 +96,28 @@ This method should load the weights from the HuggingFace's checkpoint file and a ---------------------- Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py `_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py `_. + +6. Out-of-Tree Model Integration +-------------------------------------------- + +We also provide a way to integrate a model without modifying the vLLM codebase. Step 2, 3, 4 are still required, but you can skip step 1 and 5. + +Just add the following lines in your code: + +.. code-block:: python + + from vllm import ModelRegistry + from your_code import YourModelForCausalLM + ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM) + +If you are running api server with `python -m vllm.entrypoints.openai.api_server args`, you can wrap the entrypoint with the following code: + +.. code-block:: python + + from vllm import ModelRegistry + from your_code import YourModelForCausalLM + ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM) + import runpy + runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') + +Save the above code in a file and run it with `python your_file.py args`. diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst index 9f5f672ae4f34..d8a7ac72e0175 100644 --- a/docs/source/models/engine_args.rst +++ b/docs/source/models/engine_args.rst @@ -118,3 +118,19 @@ Below, you can find an explanation of every engine argument for vLLM: .. option:: --quantization (-q) {awq,squeezellm,None} Method used to quantize the weights. + +Async Engine Arguments +---------------------- +Below are the additional arguments related to the asynchronous engine: + +.. option:: --engine-use-ray + + Use Ray to start the LLM engine in a separate process as the server process. + +.. option:: --disable-log-requests + + Disable logging requests. + +.. option:: --max-log-len + + Max number of prompt characters or prompt ID numbers being printed in log. Defaults to unlimited. \ No newline at end of file diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 9c2f5ba458eb4..e7bfdcb65316e 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -83,6 +83,10 @@ Alongside each architecture, we include some popular models that use it. - LLaMA, LLaMA-2, Vicuna, Alpaca, Yi - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. - ✅︎ + * - :code:`MiniCPMForCausalLM` + - MiniCPM + - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc. + - * - :code:`MistralForCausalLM` - Mistral, Mistral-Instruct - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc. diff --git a/docs/source/quantization/fp8_e4m3_kvcache.rst b/docs/source/quantization/fp8_e4m3_kvcache.rst new file mode 100644 index 0000000000000..fd71c00b7bf89 --- /dev/null +++ b/docs/source/quantization/fp8_e4m3_kvcache.rst @@ -0,0 +1,49 @@ +.. _fp8_e4m3_kvcache: + +FP8 E4M3 KV Cache +================== + +Quantizing the KV cache to FP8 reduces its memory footprint. This increases the number of tokens that can be stored in the cache, +improving throughput. OCP (Open Compute Project www.opencompute.org) specifies two common 8-bit floating point data formats: E5M2 +(5 exponent bits and 2 mantissa bits) and E4M3FN (4 exponent bits and 3 mantissa bits), often shortened as E4M3. One benefit of +the E4M3 format over E5M2 is that floating point numbers are represented in higher precision. However, the small dynamic range of +FP8 E4M3 (±240.0 can be represented) typically necessitates the use of a higher-precision (typically FP32) scaling factor alongside +each quantized tensor. For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling +factors of a finer granularity (e.g. per-channel). + +These scaling factors can be specified by passing an optional quantization param JSON to the LLM engine at load time. If +this JSON is not specified, scaling factors default to 1.0. These scaling factors are typically obtained when running an +unquantized model through a quantizer tool (e.g. AMD quantizer or NVIDIA AMMO). + +To install AMMO (AlgorithMic Model Optimization): + +.. code-block:: console + + $ pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo + +Studies have shown that FP8 E4M3 quantization typically only minimally degrades inference accuracy. The most recent silicon +offerings e.g. AMD MI300, NVIDIA Hopper or later support native hardware conversion to and from fp32, fp16, bf16, etc. +Thus, LLM inference is greatly accelerated with minimal accuracy loss. + + +Here is an example of how to enable this feature: + +.. code-block:: python + + # two float8_e4m3fn kv cache scaling factor files are provided under tests/fp8_kv, please refer to + # https://github.com/vllm-project/vllm/blob/main/examples/fp8/README.md to generate kv_cache_scales.json of your own. + + from vllm import LLM, SamplingParams + sampling_params = SamplingParams(temperature=1.3, top_p=0.8) + llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", + kv_cache_dtype="fp8", + quantization_param_path="./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json") + prompt = "London is the capital of" + out = llm.generate(prompt, sampling_params)[0].outputs[0].text + print(out) + + # output w/ scaling factors: England, the United Kingdom, and one of the world's leading financial, + # output w/o scaling factors: England, located in the southeastern part of the country. It is known + +Note, current prefix caching doesn't work with FP8 KV cache enabled, forward_prefix kernel should handle different KV and cache type. + diff --git a/docs/source/quantization/fp8_e5m2_kv_cache.rst b/docs/source/quantization/fp8_e5m2_kvcache.rst similarity index 83% rename from docs/source/quantization/fp8_e5m2_kv_cache.rst rename to docs/source/quantization/fp8_e5m2_kvcache.rst index f1eeb59550952..337252a00aef2 100644 --- a/docs/source/quantization/fp8_e5m2_kv_cache.rst +++ b/docs/source/quantization/fp8_e5m2_kvcache.rst @@ -1,4 +1,4 @@ -.. _fp8_e5m2_kv_cache: +.. _fp8_kv_cache: FP8 E5M2 KV Cache ================== @@ -21,7 +21,7 @@ Here is an example of how to enable this feature: # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. - llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8_e5m2") + llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) @@ -31,3 +31,6 @@ Here is an example of how to enable this feature: generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +Note, current prefix caching doesn't work with FP8 KV cache enabled, forward_prefix kernel should handle different KV and cache type. + diff --git a/examples/fp8/README.md b/examples/fp8/README.md new file mode 100644 index 0000000000000..84ad76c71862e --- /dev/null +++ b/examples/fp8/README.md @@ -0,0 +1,96 @@ +# FP8 KV Cache + +This utility extracts the KV cache scaling factors from a quantized HF (Hugging Face) model. The extracted scaling factors are saved to a JSON file, which can later be used by vLLM (variable-length language model) during runtime. This tool is particularly useful when the KV cache data type is FP8 and is intended for use on ROCm (AMD GPU) platforms. + +## Prerequisites + +- Python 3.x +- PyTorch +- NumPy +- Hugging Face Transformers +- Hugging Face Hub +- AMMO + +Before incorporating the FP8 datatype for inference workloads, you must adhere to the following steps: +1. Install all necessary prerequisites and dependencies. +2. Convert HF model into a quantized HF model. +3. Extract KV Cache Scaling Factors from quantized HF model. +4. Load KV Cache Scaling Factors into VLLM. + +### 2. Convert HF model into a quantized HF model. +Note: The following steps are adapted from the [TensorRT-LLM repository](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/README.md). + +`quantize.py` (examples/fp8/quantizer/quantize.py) uses the quantization toolkit (AMMO) to calibrate the PyTorch models and export TensorRT-LLM checkpoints. Each TensorRT-LLM checkpoint contains a config file (in .json format) and one or several rank weight files (in .safetensors format). + +The detailed quantization toolkit (AMMO) conversion guide for FP8 can be found at `examples/fp8/quantizer/README.md`. + +### 3. Extract KV Cache Scaling Factors from quantized HF model. +`extract_scales.py` (examples/fp8/extract_scales.py) can be utilized to extract the KV cache scaling factors from your quantized HF model, however at the moment, this tool exclusively supports Llama 2 models. It is also important to note the following: +1. **File Structure**: The utility operates under the assumption that all parameters, including KV cache scaling factors, corresponding to a particular Tensor Parallelism (TP) rank are stored in a single file. These files must adhere to a specific naming convention where the TP rank is immediately identified after a specific keyword (e.g., "rank") in the filename. + +2. **TP Decomposition**: The utility assumes consistency between the TP decomposition employed by the quantizer tool and that used by vLLM. + +3. **AMMO Compatibility**: Currently, the generated KV cache scaling factors for AMMO remain uniform across all TP ranks. + +```python +# prerequisites: +# - Quantized HF LLaMa 2 model +python3 examples/fp8/extract_scales.py --help +Usage: extract_scales.py [-h] --quantized_model QUANTIZED_MODEL [--load_format {auto,safetensors,npz,pt}] [--output_dir OUTPUT_DIR] [--output_name OUTPUT_NAME] [--tp_size TP_SIZE] + +KV Scale Extraction Example + +optional arguments: +--quantized_model: Specify either the local path to, or name of, a quantized HF model. It is expected that the quantization format is FP8_E4M3, for use on ROCm (AMD GPU). +Optional arguments: +--cache_dir: Specify a cache directory to use in the event of a HF model download. (Default: None) +--load_format: Specify the format of the model's tensor files containing the KV cache scaling factors. (Choices: auto, safetensors, npz, pt; Default: auto) +--revision: Specify the model's revision number. (Default: None) +--output_dir: Specify the output directory. By default the KV cache scaling factors will be saved in the model directory. (Default: None) +--output_name: Specify the output filename. (Default: kv_cache_scales.json) +--tp_size: Specify the tensor-parallel (TP) size that the quantized model should correspond to. If specified, during KV cache scaling factor extraction the observed TP size will be checked against this and an error will be raised if there is a mismatch. (Default: None) +``` +```python +Example: +python3 examples/fp8/extract_scales.py --quantized_model --tp_size --output_dir +``` +### 4. Load KV Cache Scaling Factors into VLLM. +This script evaluates the inference throughput of language models using various backends such as vLLM. It measures the time taken to process a given number of prompts and generate sequences for each prompt. The recently generated KV cache scaling factors are now integrated into the benchmarking process and allow for KV cache scaling factors to be utilized for FP8. +```python +# prerequisites: +# - LLaMa 2 kv_cache_scales.json file + +python3 benchmarks/benchmark_throughput.py --help +usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL] + [--tokenizer TOKENIZER] [--quantization {awq,gptq,squeezellm,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N] + [--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code] + [--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}] + [--quantization-param-path KV_CACHE_quantization_param_path] + +Benchmark Throughput Example +optional arguments: + -h, --help show this help message and exit + --backend {vllm,hf,mii} + --dataset DATASET Path to the dataset. + --input-len INPUT_LEN Input prompt length for each request + --output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset. + --model MODEL + --tokenizer TOKENIZER + --quantization {awq,gptq,squeezellm,None}, -q {awq,gptq,squeezellm,None} + --tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE + --n N Number of generated sequences per prompt. + --use-beam-search + --num-prompts NUM_PROMPTS Number of prompts to process. + --seed SEED + --hf-max-batch-size HF_MAX_BATCH_SIZE Maximum batch size for HF backend. + --trust-remote-code trust remote code from huggingface + --max-model-len MAX_MODEL_LEN Maximum length of a sequence (including prompt and output). If None, will be derived from the model. + --dtype {auto,half,float16,bfloat16,float,float32} data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. + --enforce-eager enforce eager execution + --kv-cache-dtype {auto,fp8} Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported ```for common inference criteria. + --quantization-param-path QUANT_PARAM_JSON Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria. +``` +``` +Example: +python3 benchmarks/benchmark_throughput.py --input-len --output-len -tp --kv-cache-dtype fp8 --quantization-param-path --model +```python diff --git a/examples/fp8/extract_scales.py b/examples/fp8/extract_scales.py new file mode 100644 index 0000000000000..5e5b31265e3af --- /dev/null +++ b/examples/fp8/extract_scales.py @@ -0,0 +1,367 @@ +import argparse +import glob +import json +import os +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple + +import numpy as np +import torch +from safetensors.torch import safe_open + +from vllm.model_executor.layers.quantization.schema import QuantParamSchema + + +# Adapted from vllm/model_executor/weight_utils.py +# The main differences are that we add the NPZ format and simplify +# its functionality drastically for our purposes (e.g. we assume that +# the quantized model exists locally and there is no need to download it) +def _prepare_hf_weights( + quantized_model_dir: str, + load_format: str = "auto", + fall_back_to_pt: bool = True, +) -> Tuple[str, List[str], bool]: + if not os.path.isdir(quantized_model_dir): + raise FileNotFoundError( + f"The quantized model directory `{quantized_model_dir}` " + "does not exist.") + use_safetensors = False + # Some quantized models use .pt files for storing the weights. + if load_format == "auto": + allow_patterns = ["*.safetensors", "*.bin"] + elif load_format == "safetensors": + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == "pt": + allow_patterns = ["*.pt"] + elif load_format == "npz": + allow_patterns = ["*.npz"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob( + os.path.join(quantized_model_dir, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if not use_safetensors: + # Exclude files that are not needed for inference. + # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 + blacklist = [ + "training_args.bin", + "optimizer.bin", + "optimizer.pt", + "scheduler.pt", + "scaler.pt", + ] + hf_weights_files = [ + f for f in hf_weights_files + if not any(f.endswith(x) for x in blacklist) + ] + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{quantized_model_dir}`") + + return hf_weights_files, use_safetensors + + +# Adapted from vllm/model_executor/weight_utils.py +def _hf_tensorfile_iterator(filename: str, load_format: str, + use_safetensors: bool): + if load_format == "npz": + assert not use_safetensors + with np.load(filename) as data: + for name in data.files: + param = torch.from_numpy(data[name]) + yield name, param + elif use_safetensors: + with safe_open(filename, framework="pt") as f: + for name in f.keys(): # NOQA: SIM118 + param = f.get_tensor(name) + yield name, param + else: + state = torch.load(filename, map_location="cpu") + for name, param in state.items(): + yield name, param + del state + torch.cuda.empty_cache() + + +def _kv_scales_extractor( + hf_tensor_files: Iterable[str], + use_safetensors: bool, + rank_keyword: str = "rank", + expected_tp_size: Optional[int] = None) -> Dict[int, Dict[int, float]]: + """ + Given a list of files containing tensor data, attempt to extract KV cache + scales from these files. Intended as a helper function taking in the output + from _prepare_hf_weights. + Args: + rank_keyword Matches the number immediately after this keyword in the + tensor filename to determine the TP rank corresponding + to said tensor file + expected_tp_size If specified, the TP size of the tensor files is checked + against this and an error is raised if they don't match. + Returns a dictionary mapping TP ranks to their relevant KV cache scales. + The per-rank scales are themselves represented as a dictionary of layer + indices to the respective per-layer scale. + """ + for char in rank_keyword: + assert not char.isdecimal( + ), f"Rank keyword {rank_keyword} contains a numeric character!" + rank_scales_map = {} + for tensor_file in hf_tensor_files: + try: + rank_idx = tensor_file.find(rank_keyword) + if rank_idx != -1: + start_idx = rank_idx + len(rank_keyword) + stop_idx = start_idx + while stop_idx < len( + tensor_file) and tensor_file[stop_idx].isdecimal(): + stop_idx += 1 + if stop_idx == start_idx: + raise RuntimeError("Did not find rank # in filename.") + rank = int(tensor_file[start_idx:stop_idx]) + elif len(hf_tensor_files) == 1: + # Since there is only one tensor file, we can assume + # that it's intended for TP rank 0 + rank = 0 + else: + raise RuntimeError( + f"Filename does not contain '{rank_keyword}'.") + except RuntimeError: + print("Unable to determine TP rank " + f"corresponding to file '{tensor_file}'") + raise + + if rank not in rank_scales_map: + layer_scales_map = {} + rank_scales_map[rank] = layer_scales_map + else: + raise RuntimeError( + f"Tensor file '{tensor_file}' shares TP rank {rank} " + "with another tensor file.") + + module_delimiter = ":" if args.load_format == "npz" else "." + for name, param in _hf_tensorfile_iterator(tensor_file, + args.load_format, + use_safetensors): + if "kv_cache_scaling_factor" in name: + nums = [ + int(s) for s in name.split(module_delimiter) + if s.isdecimal() + ] + assert len( + nums) == 1, f"Could not determine layer idx for {name}" + layer_idx = nums[0] + assert layer_idx not in layer_scales_map, f"Duplicate scaling"\ + f" factor corresponding to layer {layer_idx}" + try: + layer_scales_map[layer_idx] = param.item() + except RuntimeError: + print( + "This utility supports only per-tensor scalar scales " + f"for now. The tensor\n {name} = {param} \nis an " + "invalid scale factor.") + raise + + if all( + len(layer_scales_map) == 0 + for layer_scales_map in rank_scales_map.values()): + # Note: this is true even if the rank_scales_map is empty + print("WARNING: No KV cache scale factors found. No output saved.") + return None + empirical_tp_world_size = max(rank_scales_map.keys()) + 1 + if expected_tp_size is not None: + assert expected_tp_size == empirical_tp_world_size, \ + f"User expected TP world size = {expected_tp_size} " \ + "from model but tool is expecting TP world size = " \ + f"{empirical_tp_world_size} from model instead." + for i in range(empirical_tp_world_size): + assert i in rank_scales_map, "Expected TP world size = "\ + f"{empirical_tp_world_size} but did not find KV " \ + f"cache scaling factors for TP rank {i}" + print(f"Found TP world size = {empirical_tp_world_size} " + "when extracting KV cache scales!") + return rank_scales_map + + +def _metadata_extractor(quantized_model_dir: str, + metadata_extract_fns: \ + Dict[str, Callable[[Dict[str, Any]], Any]]) \ + -> Dict[str, Any]: + """ + Given a directory containing quantized model files, this function + aims to extract metadata from the JSON files within this directory. + Each JSON file is expected to represent a dictionary in JSON + format (referred to as a "JSON-dictionary"). Metadata extraction is + defined by a dictionary called metadata_extract_fns, where each + metadata field name is mapped to an extraction function. + + These extraction functions are designed to take a JSON-dictionary + as their only argument and return the corresponding metadata. + While extraction functions are permitted to raise exceptions, they + should only raise a KeyError or ValueError if the metadata field + cannot be extracted from the current JSON-dictionary, yet there's + a possibility of finding it in another JSON-dictionary. + + The function returns a dictionary that maps metadata fields to + their extracted data. The keys of this dictionary correspond exactly + to those in metadata_extract_fns. If any fields fail to be extracted, + their corresponding values are set to None, and a warning is printed. + """ + if not os.path.isdir(quantized_model_dir): + raise FileNotFoundError( + f"The quantized model directory `{quantized_model_dir}` " + "does not exist.") + metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json")) + + result = {} + for file in metadata_files: + with open(file) as f: + try: + metadata = json.load(f) + except json.JSONDecodeError: + print(f"Could not parse `{file}` as a valid metadata file," + " skipping it.") + continue + if not isinstance(metadata, dict): + print(f"The file `{file}` does not correspond to a " + "JSON-serialized dictionary, skipping it.") + continue + for metadata_name, extract_fn in metadata_extract_fns.items(): + try: + metadata_info = extract_fn(metadata) + if metadata_name not in result: + result[metadata_name] = metadata_info + elif metadata_info != result[metadata_name]: + raise RuntimeError( + "Metadata mismatch! Originally found " + f"{metadata_name} = {result[metadata_name]} but " + f"now found {metadata_name} = {metadata_info} in " + f"`{file}`") + except KeyError: + # It is possible that a given file does not contain some + # of our selected metadata as it could be located in some + # other metadata file. + # 'EFINAE': extract_fn failure is not an error. + pass + except ValueError: + # See above. + pass + + # Warn if we cannot find any of the requested metadata + for metadata_name in metadata_extract_fns: + if metadata_name not in result: + print("WARNING: Unable to find requested metadata field " + f"`{metadata_name}`, setting it to None.") + result[metadata_name] = None + + return result + + +def main(args): + metadata_extract_fns = { + "model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"], + "tp_size": lambda json_dict: int(json_dict["tensor_parallel"]), + "model_dtype": lambda json_dict: json_dict["dtype"] + } + recovered_metadata = _metadata_extractor(args.quantized_model, + metadata_extract_fns) + if args.tp_size is not None: + metadata_tp_size = recovered_metadata["tp_size"] + if metadata_tp_size is not None: + assert args.tp_size == metadata_tp_size, \ + f"User expected TP world size = {args.tp_size} " \ + f"but found TP world size = {metadata_tp_size} from metadata!" + expected_tp_size = args.tp_size or recovered_metadata["tp_size"] + rank_keyword = "rank" + hf_tensor_files, use_safetensors = _prepare_hf_weights( + args.quantized_model, args.load_format) + rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors, + rank_keyword, expected_tp_size) + # Postprocess: formatting to the current schema. Consider pulling it + # out into a dedicated function should it ever become more complicated. + rank_scales_map = { + rank: {k: scale[k] + for k in sorted(scale.keys())} + for rank, scale in rank_scales_map.items() + } + # TODO: Expand this with activation and weights scaling factors when + # they are used in the future + schema = QuantParamSchema( + model_type=recovered_metadata["model_type"], + kv_cache={ + "dtype": ("float8_e4m3fn" if len(rank_scales_map) > 0 else + recovered_metadata["model_dtype"]), + "scaling_factor": + rank_scales_map + }, + ) + + if args.output_dir is None: + output_file = os.path.join(args.quantized_model, args.output_name) + else: + if not os.path.isdir(args.output_dir): + os.makedirs(args.output_dir, exist_ok=True) + output_file = os.path.join(args.output_dir, args.output_name) + + with open(output_file, 'w') as f: + f.write(schema.model_dump_json(indent=4)) + print(f"Completed! KV cache scaling factors saved to {output_file}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="This simple utility extracts the " + "KV cache scaling factors from a quantized HF model " + "and saves them to a JSON file compatible with later " + "use by vLLM (pass this file to the appropriate " + "runtime typically using the argument " + "--quantization-param-path ). This is only used " + "if the KV cache dtype is FP8 and on ROCm (AMD GPU).") + parser.add_argument( + "--quantized_model", + help="Specify the directory containing a single quantized HF model. " + "It is expected that the quantization format is FP8_E4M3, for use " + "on ROCm (AMD GPU).", + required=True) + parser.add_argument( + "--load_format", + help="Optionally specify the format of the model's tensor files " + "containing the KV cache scaling factors.", + choices=["auto", "safetensors", "npz", "pt"], + default="auto") + parser.add_argument( + "--output_dir", + help="Optionally specify the output directory. By default the " + "KV cache scaling factors will be saved in the model directory, " + "however you can override this behavior here.", + default=None) + parser.add_argument( + "--output_name", + help="Optionally specify the output filename.", + # TODO: Change this once additional scaling factors are enabled + default="kv_cache_scales.json") + parser.add_argument( + "--tp_size", + help="Optionally specify the tensor-parallel (TP) size that the " + "quantized model should correspond to. If specified, during KV " + "cache scaling factor extraction the observed TP size will be " + "checked against this and an error will be raised if there is " + "a mismatch. If not specified, the quantized model's expected " + "TP size is instead inferred from the largest TP rank observed. " + "The expected TP size is cross-checked against the TP ranks " + "observed in the quantized model and an error is raised if any " + "discrepancies are found.", + default=None, + type=int) + args = parser.parse_args() + + main(args) diff --git a/examples/fp8/quantizer/README.md b/examples/fp8/quantizer/README.md new file mode 100644 index 0000000000000..8f89a74a6a367 --- /dev/null +++ b/examples/fp8/quantizer/README.md @@ -0,0 +1,32 @@ +### Quantizer Utilities +`quantize.py`: NVIDIA Quantization utilities using AMMO, ported from TensorRT-LLM: +`https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py` + +### Prerequisite + +#### AMMO (AlgorithMic Model Optimization) Installation: nvidia-ammo 0.7.1 or later +`pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo` + +#### AMMO Download (code and docs) +`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.5.0.tar.gz` +`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.7.1.tar.gz` + +### Usage + +#### Run on H100 system for speed if FP8; number of GPUs depends on the model size + +#### Example: quantize Llama2-7b model from HF to FP8 with FP8 KV Cache: +`python quantize.py --model_dir ./ll2-7b --dtype float16 --qformat fp8 --kv_cache_dtype fp8 --output_dir ./ll2_7b_fp8 --calib_size 512 --tp_size 1` + +Outputs: model structure, quantized model & parameters (with scaling factors) are in JSON and Safetensors (npz is generated only for the reference) +``` +# ll ./ll2_7b_fp8/ +total 19998244 +drwxr-xr-x 2 root root 4096 Feb 7 01:08 ./ +drwxrwxr-x 8 1060 1061 4096 Feb 7 01:08 ../ +-rw-r--r-- 1 root root 176411 Feb 7 01:08 llama_tp1.json +-rw-r--r-- 1 root root 13477087480 Feb 7 01:09 llama_tp1_rank0.npz +-rw-r--r-- 1 root root 7000893272 Feb 7 01:08 rank0.safetensors +# +``` + diff --git a/examples/fp8/quantizer/quantize.py b/examples/fp8/quantizer/quantize.py new file mode 100644 index 0000000000000..cee13b4c9c863 --- /dev/null +++ b/examples/fp8/quantizer/quantize.py @@ -0,0 +1,367 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # noqa: E501 +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Adapted from examples/quantization/hf_ptq.py +""" + +import argparse +import copy +import json +import random +import time + +import ammo.torch.quantization as atq +import numpy as np +import torch +from ammo.torch.export import export_model_config +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import AutoModelForCausalLM, AutoTokenizer + +RAND_SEED = 1234 +MAX_SEQ_LEN = 2048 + +EMPTY_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "enable": False, + }, + "*input_quantizer": { + "enable": False + }, + "*lm_head*": { + "enable": False + }, + "*output_layer*": { + "enable": False + }, + "default": { + "enable": False + }, + }, + "algorithm": "max", +} + +KV_CACHE_CFG = { + "*.query_key_value.output_quantizer": { + "num_bits": 8, + "axis": None, + "enable": True + }, + "*.Wqkv.output_quantizer": { + "num_bits": 8, + "axis": None, + "enable": True + }, + "*.W_pack.output_quantizer": { + "num_bits": 8, + "axis": None, + "enable": True + }, + "*.c_attn.output_quantizer": { + "num_bits": 8, + "axis": None, + "enable": True + }, + "*.k_proj.output_quantizer": { + "num_bits": 8, + "axis": None, + "enable": True + }, + "*.v_proj.output_quantizer": { + "num_bits": 8, + "axis": None, + "enable": True + }, +} + +QUANT_CFG_CHOICES = { + "int8_sq": atq.INT8_SMOOTHQUANT_CFG, + "fp8": atq.FP8_DEFAULT_CFG, + "int4_awq": atq.INT4_AWQ_CFG, + "w4a8_awq": atq.W4A8_AWQ_BETA_CFG, + "int8_wo": EMPTY_CFG, + "int4_wo": EMPTY_CFG, + "full_prec": EMPTY_CFG, +} + +MODEL_NAME_PATTERN_MAP = { + "GPT2": "gpt2", + "Xverse": "llama", + "Llama": "llama", + "Mistral": "llama", + "GPTJ": "gptj", + "FalconForCausalLM": "falcon", + "RWForCausalLM": "falcon", + "baichuan": "baichuan", + "MPT": "mpt", + "Bloom": "bloom", + "ChatGLM": "chatglm", + "QWen": "qwen", +} + + +def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None): + print(f"Initializing tokenizer from {ckpt_path}") + tokenizer = AutoTokenizer.from_pretrained( + ckpt_path, + model_max_length=max_seq_len, + padding_side="left", + trust_remote_code=True, + ) + if model_type and model_type == "qwen": + # qwen use token id 151643 as pad and eos tokens + tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643) + tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643) + + # can't set attribute 'pad_token' for "" + if tokenizer.pad_token != "": + tokenizer.pad_token = tokenizer.eos_token + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + assert (tokenizer.pad_token + is not None), f"Pad token for {model_type} cannot be set!" + + return tokenizer + + +def get_model(ckpt_path, dtype="fp16", device="cuda"): + print(f"Initializing model from {ckpt_path}") + if dtype == "bf16" or dtype == "bfloat16": + dtype = torch.bfloat16 + elif dtype == "fp16" or dtype == "float16": + dtype = torch.float16 + elif dtype == "fp32" or dtype == "float32": + dtype = torch.float32 + else: + raise NotImplementedError(f"Unknown dtype {dtype}") + + # model_kwargs = {"torch_dtype": dtype} + model_kwargs = {"torch_dtype": "auto"} + + model = AutoModelForCausalLM.from_pretrained(ckpt_path, + device_map="auto", + **model_kwargs, + trust_remote_code=True) + model.eval() + + model_dtype = next(model.parameters()).dtype + if dtype != model_dtype: + print("[TensorRT-LLM][WARNING] The manually set model data type is " + f"{dtype}, but the data type of the HuggingFace model is " + f"{model_dtype}.") + + return model + + +def get_model_type(model): + for k, v in MODEL_NAME_PATTERN_MAP.items(): + if k.lower() in type(model).__name__.lower(): + return v + return None + + +def get_calib_dataloader(data="cnn_dailymail", + tokenizer=None, + batch_size=1, + calib_size=512, + block_size=512, + device=None): + print("Loading calibration dataset") + if data == "pileval": + dataset = load_dataset( + "json", + data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst", + split="train") + dataset = dataset["text"][:calib_size] + elif data == "cnn_dailymail": + dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train") + dataset = dataset["article"][:calib_size] + else: + raise NotImplementedError + + batch_encoded = tokenizer.batch_encode_plus(dataset, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=block_size) + if device: + batch_encoded = batch_encoded.to(device) + batch_encoded = batch_encoded["input_ids"] + + calib_dataloader = DataLoader(batch_encoded, + batch_size=batch_size, + shuffle=False) + + return calib_dataloader + + +def quantize_model(model, quant_cfg, calib_dataloader=None): + + def calibrate_loop(): + if calib_dataloader is None: + return + """Adjusts weights and scaling factors based on selected algorithms.""" + for idx, data in enumerate(calib_dataloader): + print(f"Calibrating batch {idx}") + model(data) + + print("Starting quantization...") + start_time = time.time() + atq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + end_time = time.time() + print("Quantization done. Total time used: {:.2f} s.".format(end_time - + start_time)) + + return model + + +def main(args): + if not torch.cuda.is_available(): + raise EnvironmentError("GPU is required for inference.") + + random.seed(RAND_SEED) + np.random.seed(RAND_SEED) + + model = get_model(args.model_dir, args.dtype, args.device) + model_type = get_model_type(model) + tokenizer = get_tokenizer(args.model_dir, model_type=model_type) + + if args.qformat in ["full_prec", "int8_wo", "int4_wo" + ] and args.kv_cache_dtype is None: + print(f"No quantization applied, export {args.dtype} model") + else: + if "awq" in args.qformat: + if args.calib_size > 32: + print("AWQ calibration could take longer with calib_size = " + f"{args.calib_size}, Using calib_size=32 instead") + args.calib_size = 32 + print("\nAWQ calibration could take longer than other calibration " + "methods. Please increase the batch size to speed up the " + "calibration process. Batch size can be set by adding the " + "argument --batch_size to the command line.\n") + + calib_dataloader = get_calib_dataloader( + tokenizer=tokenizer, + batch_size=args.batch_size, + calib_size=args.calib_size, + device=args.device, + ) + + if args.qformat in QUANT_CFG_CHOICES: + quant_cfg = QUANT_CFG_CHOICES[args.qformat] + else: + raise ValueError( + f"Unsupported quantization format: {args.qformat}") + + if "awq" in args.qformat: + quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat]) + weight_quantizer = quant_cfg["quant_cfg"][ + "*weight_quantizer"] # type: ignore + if isinstance(weight_quantizer, list): + weight_quantizer = weight_quantizer[0] + weight_quantizer["block_sizes"][-1] = args.awq_block_size + + if args.kv_cache_dtype is not None: + if args.kv_cache_dtype == "fp8": + for value in KV_CACHE_CFG.values(): + value.update({"num_bits": (4, 3)}) # type: ignore + quant_cfg["quant_cfg"].update(KV_CACHE_CFG) # type: ignore + + print(quant_cfg) + + model = quantize_model(model, quant_cfg, calib_dataloader) + + with torch.inference_mode(): + if model_type is None: + print(f"Unknown model type {type(model).__name__}. Continue " + "exporting...") + model_type = f"unknown:{type(model).__name__}" + + export_path = args.output_dir + start_time = time.time() + + if args.qformat == "int4_awq" and model_type == "qwen": + torch.save(model.state_dict(), export_path) + else: + export_npz = (model_type not in [ + 'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan' + ]) + + # export safetensors + export_model_config( + model, + model_type, + getattr(torch, args.dtype), + export_dir=export_path, + inference_tensor_parallel=args.tp_size, + inference_pipeline_parallel=args.pp_size, + # export_tensorrt_llm_config=(not export_npz), + export_tensorrt_llm_config=False, + export_npz=export_npz) + + # Workaround for wo quantization + if args.qformat in ["int8_wo", "int4_wo", "full_prec"]: + with open(f"{export_path}/config.json", 'r') as f: + tensorrt_llm_config = json.load(f) + if args.qformat == "int8_wo": + tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16' + elif args.qformat == "int4_wo": + tensorrt_llm_config["quantization"]["quant_algo"] = 'W4A16' + else: + tensorrt_llm_config["quantization"]["quant_algo"] = None + with open(f"{export_path}/config.json", "w") as f: + json.dump(tensorrt_llm_config, f, indent=4) + + end_time = time.time() + print("Quantized model exported to {} \nTotal time used {:.2f} s.". + format(export_path, end_time - start_time)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model_dir", + help="Specify where the HuggingFace model is", + required=True) + parser.add_argument("--device", default="cuda") + parser.add_argument("--dtype", help="Model data type.", default="float16") + parser.add_argument( + "--qformat", + help="Quantization format.", + default="full_prec", + choices=[ + "fp8", "int8_sq", "int4_awq", "w4a8_awq", "int8_wo", "int4_wo", + "full_prec" + ], + ) + parser.add_argument("--batch_size", + help="Batch size for calibration.", + type=int, + default=1) + parser.add_argument("--calib_size", + help="Number of samples for calibration.", + type=int, + default=512) + parser.add_argument("--output_dir", default="exported_model") + parser.add_argument("--tp_size", type=int, default=1) + parser.add_argument("--pp_size", type=int, default=1) + parser.add_argument("--awq_block_size", type=int, default=128) + parser.add_argument("--kv_cache_dtype", + help="KV Cache dtype.", + default=None, + choices=["int8", "fp8", None]) + args = parser.parse_args() + + main(args) diff --git a/pyproject.toml b/pyproject.toml index 9d042601ccb8e..2a00d6796ee02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ "ninja", "packaging", "setuptools >= 49.4.0", - "torch == 2.1.2", + "torch == 2.2.1", "wheel", ] build-backend = "setuptools.build_meta" @@ -13,6 +13,10 @@ build-backend = "setuptools.build_meta" [tool.ruff] # Allow lines to be as long as 80. line-length = 80 +exclude = [ + # External file, leaving license intact + "examples/fp8/quantizer/quantize.py" +] [tool.ruff.lint] select = [ diff --git a/requirements-build.txt b/requirements-build.txt index a8efcde590bbf..2bc07fb152aac 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -3,5 +3,5 @@ cmake>=3.21 ninja packaging setuptools>=49.4.0 -torch==2.1.2 +torch==2.2.1 wheel diff --git a/requirements.txt b/requirements-common.txt similarity index 57% rename from requirements.txt rename to requirements-common.txt index 7e84034dde439..ff053388a23e1 100644 --- a/requirements.txt +++ b/requirements-common.txt @@ -1,20 +1,14 @@ -cmake>=3.21 +cmake >= 3.21 ninja # For faster builds. psutil -ray >= 2.9 sentencepiece # Required for LLaMA tokenizer. numpy -torch == 2.1.2 requests -psutil py-cpuinfo transformers >= 4.39.1 # Required for StarCoder2 & Llava. -xformers == 0.0.23.post1 # Required for CUDA 12.1. fastapi uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 -pynvml == 11.5.0 -triton >= 2.1.0 -outlines == 0.0.34 -tiktoken == 0.6.0 # Required for DBRX tokenizer +tiktoken == 0.6.0 # Required for DBRX tokenizer +outlines == 0.0.34 # Requires torch >= 2.1.0 \ No newline at end of file diff --git a/requirements-cpu.txt b/requirements-cpu.txt new file mode 100644 index 0000000000000..36d20bc9473ea --- /dev/null +++ b/requirements-cpu.txt @@ -0,0 +1,6 @@ +# Common dependencies +-r requirements-common.txt + +# Dependencies for x86_64 CPUs +torch == 2.2.1+cpu +triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error. diff --git a/requirements-cuda.txt b/requirements-cuda.txt new file mode 100644 index 0000000000000..6ee75e8139c04 --- /dev/null +++ b/requirements-cuda.txt @@ -0,0 +1,10 @@ +# Common dependencies +-r requirements-common.txt + +# Dependencies for NVIDIA GPUs +ray >= 2.9 +pynvml == 11.5.0 +vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library +torch == 2.2.1 +xformers == 0.0.25 # Requires PyTorch 2.2.1 +triton >= 2.1.0 diff --git a/requirements-neuron.txt b/requirements-neuron.txt index 6828bd4fd1fce..92b705b4b2d67 100644 --- a/requirements-neuron.txt +++ b/requirements-neuron.txt @@ -1,12 +1,7 @@ -sentencepiece # Required for LLaMA tokenizer. -numpy +# Common dependencies +-r requirements-common.txt + +# Dependencies for Neuron devices transformers-neuronx >= 0.9.0 torch-neuronx >= 2.1.0 neuronx-cc -fastapi -uvicorn[standard] -pydantic >= 2.0 # Required for OpenAI server. -prometheus_client >= 0.18.0 -requests -psutil -py-cpuinfo \ No newline at end of file diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 0dc2f0e664114..903845b64d98f 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -1,17 +1,5 @@ -cmake>=3.21 -ninja # For faster builds. -typing-extensions>=4.8.0 -starlette -requests -py-cpuinfo -psutil +# Common dependencies +-r requirements-common.txt + +# Dependencies for AMD GPUs ray == 2.9.3 -sentencepiece # Required for LLaMA tokenizer. -numpy -tokenizers>=0.15.0 -transformers >= 4.39.1 # Required for StarCoder2 & Llava. -fastapi -uvicorn[standard] -pydantic >= 2.0 # Required for OpenAI server. -prometheus_client >= 0.18.0 -outlines == 0.0.34 diff --git a/setup.py b/setup.py index 813f865ffb63e..616324e812bcd 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,8 @@ from torch.utils.cpp_extension import CUDA_HOME ROOT_DIR = os.path.dirname(__file__) logger = logging.getLogger(__name__) +# Target device of vLLM, supporting [cuda (by default), rocm, neuron, cpu] +VLLM_TARGET_DEVICE = os.getenv("VLLM_TARGET_DEVICE", "cuda") # vLLM only supports Linux platform assert sys.platform.startswith( @@ -112,6 +114,7 @@ class cmake_build_ext(build_ext): '-DCMAKE_BUILD_TYPE={}'.format(cfg), '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}'.format(outdir), '-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY={}'.format(self.build_temp), + '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE), ] verbose = bool(int(os.getenv('VERBOSE', '0'))) @@ -186,11 +189,14 @@ class cmake_build_ext(build_ext): def _is_cuda() -> bool: has_cuda = torch.version.cuda is not None - return has_cuda and not (_is_neuron() or _is_tpu()) + return (VLLM_TARGET_DEVICE == "cuda" + and has_cuda + and not (_is_neuron() or _is_tpu())) def _is_hip() -> bool: - return torch.version.hip is not None + return (VLLM_TARGET_DEVICE == "cuda" + or VLLM_TARGET_DEVICE == "rocm") and torch.version.hip is not None def _is_neuron() -> bool: @@ -206,8 +212,12 @@ def _is_tpu() -> bool: return True # FIXME +def _is_cpu() -> bool: + return VLLM_TARGET_DEVICE == "cpu" + + def _build_custom_ops() -> bool: - return _is_cuda() or _is_hip() + return _is_cuda() or _is_hip() or _is_cpu() def _install_punica() -> bool: @@ -307,6 +317,8 @@ def get_vllm_version() -> str: version += f"+neuron{neuron_version_str}" elif _is_tpu(): version += "+tpu" + elif _is_cpu(): + version += "+cpu" else: raise RuntimeError("Unknown runtime environment") @@ -324,22 +336,40 @@ def read_readme() -> str: def get_requirements() -> List[str]: """Get Python package dependencies from requirements.txt.""" + + def _read_requirements(filename: str) -> List[str]: + with open(get_path(filename)) as f: + requirements = f.read().strip().split("\n") + resolved_requirements = [] + for line in requirements: + if line.startswith("-r "): + resolved_requirements += _read_requirements(line.split()[1]) + else: + resolved_requirements.append(line) + return resolved_requirements + if _is_cuda(): - with open(get_path("requirements.txt")) as f: - requirements = f.read().strip().split("\n") + requirements = _read_requirements("requirements-cuda.txt") + cuda_major = torch.version.cuda.split(".")[0] + modified_requirements = [] + for req in requirements: + if "vllm-nccl-cu12" in req: + modified_requirements.append( + req.replace("vllm-nccl-cu12", f"vllm-nccl-cu{cuda_major}")) + else: + modified_requirements.append(req) + requirements = modified_requirements elif _is_hip(): - with open(get_path("requirements-rocm.txt")) as f: - requirements = f.read().strip().split("\n") + requirements = _read_requirements("requirements-rocm.txt") elif _is_neuron(): - with open(get_path("requirements-neuron.txt")) as f: - requirements = f.read().strip().split("\n") + requirements = _read_requirements("requirements-neuron.txt") elif _is_tpu(): - with open(get_path("requirements-tpu.txt")) as f: - requirements = f.read().strip().split("\n") + requirements = _read_requirements("requirements-tpu.txt") + elif _is_cpu(): + requirements = _read_requirements("requirements-cpu.txt") else: raise ValueError( - "Unsupported platform, please use CUDA, ROCM or Neuron.") - + "Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.") return requirements diff --git a/tests/conftest.py b/tests/conftest.py index 770da1e6f14b8..e00f3eb871e37 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -55,10 +55,24 @@ def cleanup(): torch.cuda.empty_cache() +@pytest.fixture() +def should_do_global_cleanup_after_test(request) -> bool: + """Allow subdirectories to skip global cleanup by overriding this fixture. + This can provide a ~10x speedup for non-GPU unit tests since they don't need + to initialize torch. + """ + + if request.node.get_closest_marker("skip_global_cleanup"): + return False + + return True + + @pytest.fixture(autouse=True) -def cleanup_fixture(): +def cleanup_fixture(should_do_global_cleanup_after_test: bool): yield - cleanup() + if should_do_global_cleanup_after_test: + cleanup() @pytest.fixture(scope="session") diff --git a/tests/core/block/conftest.py b/tests/core/block/conftest.py new file mode 100644 index 0000000000000..0464d6a74da61 --- /dev/null +++ b/tests/core/block/conftest.py @@ -0,0 +1,12 @@ +import pytest + + +@pytest.fixture() +def should_do_global_cleanup_after_test() -> bool: + """Disable the global cleanup fixture for tests in this directory. This + provides a ~10x speedup for unit tests that don't load a model to GPU. + + This requires that tests in this directory clean up after themselves if they + use the GPU. + """ + return False diff --git a/tests/core/block/e2e/conftest.py b/tests/core/block/e2e/conftest.py index e1a9dd28e5737..1d99cb5d32219 100644 --- a/tests/core/block/e2e/conftest.py +++ b/tests/core/block/e2e/conftest.py @@ -1,25 +1,10 @@ -import contextlib -import gc - import pytest -import ray -import torch +from tests.conftest import cleanup from vllm import LLM -from vllm.model_executor.parallel_utils.parallel_state import ( - destroy_model_parallel) from vllm.model_executor.utils import set_random_seed -def cleanup(): - destroy_model_parallel() - with contextlib.suppress(AssertionError): - torch.distributed.destroy_process_group() - gc.collect() - torch.cuda.empty_cache() - ray.shutdown() - - @pytest.fixture def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, seed): diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index 283d99fe0b193..94b65401e1dd4 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -16,7 +16,7 @@ from vllm import SamplingParams # Allow only 5 sequences of ~1024 tokens in worst case. "block_size": 16, - "forced_num_gpu_blocks": 5 * (64 + 1), + "num_gpu_blocks_override": 5 * (64 + 1), }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{ @@ -77,6 +77,159 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator, assert baseline_token_ids == test_token_ids +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # skip cuda graph creation for fast test. + "enforce_eager": True, + + # Use a large block size to trigger more copy-on-writes. + "block_size": 32, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{ + "use_v2_block_manager": False +}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}]) +@pytest.mark.parametrize("batch_size", [10]) +@pytest.mark.parametrize("seed", [1]) +def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator, + test_llm_generator, batch_size): + """Verify beam search equality with block manager v1 and v2. + + This requires copy-on-writes; if the v1 and v2 output is the same, then + we have some confidence cow is working. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + use_beam_search=True, + best_of=2, + ) + + print('Getting token ids from block manager v1') + baseline_token_ids = get_token_ids_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + print('Getting token ids from block manager v2') + test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, sampling_params) + + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, + test_token_ids): + assert expected_token_ids == actual_token_ids + + assert baseline_token_ids == test_token_ids + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # Our prompts will generate 128 tokens; since the prompts themselves are + # small, we don't need much KV space beyond 128. + "max_model_len": 160, + + # skip cuda graph creation for fast test. + "enforce_eager": True, + + # Lookahead scheduling only supported in v2 block manager. + "use_v2_block_manager": True, + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + "block_size": 16, + + # Allow only 2 sequences of ~128 tokens in worst case. + # Note 8 = 128/block_size + "num_gpu_blocks_override": 2 * (8 + 1), + }, + { + "block_size": 8, + + # Allow only 2 sequences of ~128 tokens in worst case. + # Note 16 = 128/block_size + "num_gpu_blocks_override": 2 * (16 + 1), + } + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{ + "num_lookahead_slots": 0, +}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [{ + # We run one test with block_size < lookahead_slots, one test with + # block_size > lookahead_slots + "num_lookahead_slots": 10, + }]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, + test_llm_generator, + batch_size): + """Verify vLLM produces the same output with greedy sampling, when lookahead + scheduling is used vs. not. + + Lookahead scheduling is not expected to modify the output, as it simply + allocates empty slots ahead of the known token ids in a sliding fashion. + + This test constrains the total number of blocks to force preemption. It also + varies the block size so that the lookahead size is less than and greater + than the block size. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + print('Getting token ids without lookahead scheduling') + baseline_token_ids = get_token_ids_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + print('Getting token ids with lookahead scheduling') + test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, sampling_params) + + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, + test_token_ids): + assert expected_token_ids == actual_token_ids + + assert baseline_token_ids == test_token_ids + + def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): for llm in llm_generator: outputs = llm.generate(prompts, sampling_params, use_tqdm=True) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py new file mode 100644 index 0000000000000..1e8e4ccdfb151 --- /dev/null +++ b/tests/core/block/test_block_manager_v2.py @@ -0,0 +1,103 @@ +import pytest + +from vllm.core.block_manager_v2 import BlockSpaceManagerV2 +from vllm.core.interfaces import AllocStatus +from vllm.sequence import Logprob, SequenceStatus +from vllm.utils import chunk_list + +from ..utils import create_seq_group + + +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80]) +@pytest.mark.parametrize("num_seqs_per_group", [1, 4]) +@pytest.mark.parametrize("watermark", [0.0, 0.5]) +def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, + num_gpu_blocks: int, watermark: float): + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + watermark=watermark, + ) + num_watermark_blocks = int(watermark * num_gpu_blocks) + + num_output_blocks_per_seq = 1 + + # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but + # the current implementation assumes all seqs are new prompts / don't have + # different output lens. + num_output_blocks = num_output_blocks_per_seq + + for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks): + seq_group = create_seq_group( + seq_prompt_len=block_size * num_prompt_blocks, + seq_output_lens=[ + block_size * num_output_blocks_per_seq + for _ in range(num_seqs_per_group) + ], + ) + + assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks + + can_allocate_result = block_manager.can_allocate(seq_group) + + num_required_blocks = num_prompt_blocks + num_output_blocks + + if num_gpu_blocks - num_required_blocks < num_watermark_blocks: + assert can_allocate_result == AllocStatus.NEVER + elif num_gpu_blocks >= num_required_blocks: + assert can_allocate_result == AllocStatus.OK + else: + assert can_allocate_result == AllocStatus.LATER + + +@pytest.mark.parametrize("block_size", [1, 8]) +@pytest.mark.parametrize("prompt_len", [1, 7, 8]) +@pytest.mark.parametrize("num_slots_to_append", [1, 8, 129]) +@pytest.mark.parametrize("num_lookahead_slots", [0, 10]) +def test_append_slots(block_size, prompt_len, num_slots_to_append, + num_lookahead_slots): + """Verify append_slots consumes the correct number of blocks from the block + table. + """ + + num_gpu_blocks = 1024 + watermark = 0.1 + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=0, + watermark=watermark, + ) + + seq_group = create_seq_group( + seq_prompt_len=prompt_len, + seq_output_lens=[0], + ) + + # Allocate seq + assert block_manager.can_allocate(seq_group) + block_manager.allocate(seq_group) + + # Seq seq to RUNNING + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + # Append tokens to the sequeqnce + for token_id in range(num_slots_to_append): + seq.append_token_id(token_id, {token_id: Logprob(0.0)}) + + # Append slots for new tokens and lookahead slots. + free_blocks_before_append = block_manager.get_num_free_gpu_blocks() + block_manager.append_slots(seq, num_lookahead_slots) + num_consumed_blocks = (free_blocks_before_append - + block_manager.get_num_free_gpu_blocks()) + + # Expect consumed blocks to be new blocks required to support the new slots. + expected_consumed_blocks = len( + chunk_list( + list( + range(prompt_len + num_slots_to_append + num_lookahead_slots)), + block_size)) - len(chunk_list(list(range(prompt_len)), block_size)) + assert num_consumed_blocks == expected_consumed_blocks diff --git a/tests/core/block/test_block_space_manager.py b/tests/core/block/test_block_space_manager.py deleted file mode 100644 index eec8cbcb38803..0000000000000 --- a/tests/core/block/test_block_space_manager.py +++ /dev/null @@ -1,50 +0,0 @@ -import pytest - -from vllm.core.block_manager_v2 import BlockSpaceManagerV2 -from vllm.core.interfaces import AllocStatus - -from ..utils import create_seq_group - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80]) -@pytest.mark.parametrize("num_seqs_per_group", [1, 4]) -@pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, - num_gpu_blocks: int, watermark: float): - block_manager = BlockSpaceManagerV2( - block_size=block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - watermark=watermark, - ) - num_watermark_blocks = int(watermark * num_gpu_blocks) - - num_output_blocks_per_seq = 1 - - # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but - # the current implementation assumes all seqs are new prompts / don't have - # different output lens. - num_output_blocks = num_output_blocks_per_seq - - for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks): - seq_group = create_seq_group( - seq_prompt_lens=block_size * num_prompt_blocks, - seq_output_lens=[ - block_size * num_output_blocks_per_seq - for _ in range(num_seqs_per_group) - ], - ) - - assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks - - can_allocate_result = block_manager.can_allocate(seq_group) - - num_required_blocks = num_prompt_blocks + num_output_blocks - - if num_gpu_blocks - num_required_blocks < num_watermark_blocks: - assert can_allocate_result == AllocStatus.NEVER - elif num_gpu_blocks >= num_required_blocks: - assert can_allocate_result == AllocStatus.OK - else: - assert can_allocate_result == AllocStatus.LATER diff --git a/tests/core/block/test_block_table.py b/tests/core/block/test_block_table.py index a7c5aa2b1df59..3481d6b4312c1 100644 --- a/tests/core/block/test_block_table.py +++ b/tests/core/block/test_block_table.py @@ -498,3 +498,78 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int, # After free, expect all blocks to be freed. assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks + + +@pytest.mark.parametrize("block_size", [1, 8]) +@pytest.mark.parametrize("sequence_len", [1, 16, 129]) +@pytest.mark.parametrize("num_new_tokens", [1, 16, 129]) +@pytest.mark.parametrize("num_lookahead_slots", [1, 7, 8]) +@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) +def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int, + num_new_tokens: int, + num_lookahead_slots: int, + allocator_type: str): + """Verify correct calculation of get_num_blocks_touched_by_append_slots. + + This is done by using copy-on-write, which requires any modified block to + be copied before write if the refcount > 1. We set the refcount>1 by forking + a sequence, then measure the free blocks before and after an append. If the + number of consumed blocks equals what `get_num_blocks_touched_by_append_ + slots` returns, then the calculation is correct. + """ + + num_gpu_blocks = 1024 + + allocator = CpuGpuBlockAllocator.create( + allocator_type=allocator_type, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=0, + block_size=block_size, + ) + + token_ids = list(range(sequence_len)) + token_ids_to_append = list(range(num_new_tokens)) + + block_table = BlockTable( + block_size=block_size, + block_allocator=allocator, + ) + + block_table.allocate(token_ids=token_ids, device=Device.GPU) + + # Add lookahead before fork so both sequences have the same lookahead + # blocks. + block_table.ensure_num_empty_slots(num_empty_slots=num_lookahead_slots) + + # Fork sequence so that every block has refcount > 1. + _ = block_table.fork() + + # Determine how many blocks should be touched. + expected_num_touched_blocks = ( + block_table.get_num_blocks_touched_by_append_slots( + token_ids=token_ids_to_append, + num_lookahead_slots=num_lookahead_slots)) + + # Measure how many blocks are touched by measuring num_free_blocks before + # and after the append. + # + # We expect append_token_ids to CoW all mutated blocks that have refcount>1. + num_free_blocks_before_append = allocator.get_num_free_blocks(Device.GPU) + block_table.append_token_ids(token_ids_to_append, num_lookahead_slots) + num_consumed_blocks = (num_free_blocks_before_append - + allocator.get_num_free_blocks(Device.GPU)) + + # TODO(cade) ensure equality when num_lookahead_slots > 0. + # The reason we have < is because lookahead blocks are not copied eagerly; + # they are copied on first write. This will cause issues for beam search + + # speculative decoding. This is acceptable for now as it is a large effort + # to combine the two. To fix this, we can ensure single sequence ownership + # of lookahead blocks by appending empty slots to each block, which will + # trigger the CoW. + # + # Until then, we can accept that the consumed tokens are <= the expected + # tokens when appending with lookahead. + if num_lookahead_slots > 0: + assert num_consumed_blocks <= expected_num_touched_blocks + else: + assert num_consumed_blocks == expected_num_touched_blocks diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 93226cba1909c..62984ef4caabb 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -103,9 +103,9 @@ def test_append_slot_single_seq(): block_manager.allocate(seq_group) # Nothing to append. Sequence has no new logical blocks. - assert block_manager.can_append_slot(seq_group) + assert block_manager.can_append_slots(seq_group) before_blocks = block_manager.get_num_free_gpu_blocks() - assert not block_manager.append_slot(prompt) + assert not block_manager.append_slots(prompt) after_blocks = block_manager.get_num_free_gpu_blocks() assert before_blocks == after_blocks @@ -114,9 +114,9 @@ def test_append_slot_single_seq(): token_id = i + 5 prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) - assert block_manager.can_append_slot(seq_group) + assert block_manager.can_append_slots(seq_group) before_blocks = block_manager.get_num_free_gpu_blocks() - assert not block_manager.append_slot(prompt) + assert not block_manager.append_slots(prompt) after_blocks = block_manager.get_num_free_gpu_blocks() assert before_blocks - after_blocks == 1 @@ -150,13 +150,13 @@ def test_append_slot_cow(): child.append_token_id(token_id, {token_id: Logprob(0.0)}) block_manager.fork(prompt, child) - assert block_manager.can_append_slot(seq_group) + assert block_manager.can_append_slots(seq_group) before_blocks = block_manager.get_num_free_gpu_blocks() - maybe_src_dst_block = block_manager.append_slot(child) - assert maybe_src_dst_block is not None - src_block, dst_block = maybe_src_dst_block - assert src_block != dst_block + cows = block_manager.append_slots(child) + assert cows + for src_block, dst_blocks in cows.items(): + assert src_block not in dst_blocks after_blocks = block_manager.get_num_free_gpu_blocks() assert before_blocks - after_blocks == 1 @@ -184,7 +184,7 @@ def test_fork(): token_id = 4 # Append token to child. Block is shared so copy on write occurs. child.append_token_id(token_id, {token_id: Logprob(0.0)}) - block_manager.append_slot(child) + block_manager.append_slots(child) assert block_manager.get_block_table( prompt) != block_manager.get_block_table(child) @@ -325,7 +325,7 @@ def test_sliding_window_multi_seq(): token_id = 4 # Append token to child. Block is shared so copy on write occurs. child.append_token_id(token_id, {token_id: Logprob(0.0)}) - block_manager.append_slot(child) + block_manager.append_slots(child) # assert the number of blocks allocated is correct # we will use now one block more. Each seq will use 2 blocks, @@ -335,7 +335,7 @@ def test_sliding_window_multi_seq(): token_id = 5 parent.append_token_id(token_id, {token_id: Logprob(0.0)}) - block_manager.append_slot(parent) + block_manager.append_slots(parent) # assert the number of blocks allocated is correct # no change, because both sequences are still just sharing one block diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py new file mode 100644 index 0000000000000..05e62ced5898f --- /dev/null +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -0,0 +1,563 @@ +from typing import List +from unittest.mock import MagicMock + +import pytest # noqa + +from vllm.config import CacheConfig, SchedulerConfig +from vllm.core.scheduler import Scheduler +from vllm.sequence import Logprob, SequenceGroup + +from .utils import create_dummy_prompt + + +def get_sequence_groups(scheduler_output): + return [s.seq_group for s in scheduler_output.scheduled_seq_groups] + + +def append_new_token(seq_group, token_id: int): + for seq in seq_group.get_seqs(): + seq.append_token_id(token_id, {token_id: Logprob(token_id)}) + + +def schedule_and_update_computed_tokens(scheduler): + metas, out = scheduler.schedule() + for s, meta in zip(out.scheduled_seq_groups, metas): + s.seq_group.update_num_computed_tokens(meta.token_chunk_size) + return metas, out + + +def test_simple(): + """Verify basic scheduling works.""" + block_size = 4 + num_seq_group = 4 + max_model_len = 16 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig(max_num_batched_tokens, + num_seq_group, + max_model_len, + enable_chunked_prefill=True) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + # Add seq groups to scheduler. + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Schedule seq groups prompts. + num_tokens = block_size * num_seq_group + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + assert out.num_batched_tokens == num_tokens + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == num_seq_group + for s in running: + append_new_token(s, 1) + + # Schedule seq groups generation. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + assert out.num_batched_tokens == num_seq_group + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == num_seq_group + + +def test_chunk(): + """Verify prefills are chunked properly.""" + block_size = 4 + max_seqs = 60 + max_model_len = 80 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig(max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Verify the second request is chunked. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + assert seq_group_meta[0].token_chunk_size == 60 + # Verify it is chunked. + assert seq_group_meta[1].token_chunk_size == 4 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 64 + # Only the first seq group has a new token appended. + append_new_token(running[0], 1) + + # One chunked prefill, and one decoding. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + # The first one is decoding. + assert seq_group_meta[0].token_chunk_size == 1 + # The second one is a chunked prefill. + assert seq_group_meta[1].token_chunk_size == 56 + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 57 + + +def test_complex(): + block_size = 4 + max_seqs = 60 + max_model_len = 80 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig(max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + assert seq_group.is_prefill() + + # Verify the second request is chunked. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + + assert set(get_sequence_groups(out)) == set(running) + assert seq_group_meta[0].token_chunk_size == 60 + # Verify it is chunked. + assert seq_group_meta[1].token_chunk_size == 4 + assert not running[0].is_prefill() + assert running[1].is_prefill() + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 64 + # Only the first seq group has a new token appended. + append_new_token(running[0], 1) + + # Add 2 more requsets. + for i in range(2, 4): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Decoding & chunked prefill & first chunk of 3rd request is scheduled. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 3 + # The first one is decoding. + assert seq_group_meta[0].token_chunk_size == 1 + # The second one is a chunked prefill. + assert seq_group_meta[1].token_chunk_size == 56 + # The third one is also chunked. + assert seq_group_meta[2].token_chunk_size == 7 + # Two of them are in chunked prefill. + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 64 + # The first 2 requests are now in decodine phase. + append_new_token(running[0], 1) + assert not running[0].is_prefill() + append_new_token(running[1], 1) + assert not running[1].is_prefill() + # The third request is still in prefill stage. + assert running[2].is_prefill() + + +def test_maximal_decoding(): + """Verify decoding requests are prioritized.""" + block_size = 4 + max_seqs = 2 + max_model_len = 2 + max_num_batched_tokens = 2 + scheduler_config = SchedulerConfig(max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), prompt_length=2) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + assert seq_group.is_prefill() + + # The first prefill is scheduled. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 1 + assert seq_group_meta[0].token_chunk_size == 2 + assert not running[0].is_prefill() + assert running[1].is_prefill() + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 2 + # Only the first seq group has a new token appended. + append_new_token(running[0], 1) + + # Create one more seq_group. + _, seq_group = create_dummy_prompt("3", prompt_length=2) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + assert seq_group.is_prefill() + # The first decoding + second chunk is scheduled. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 2 + assert seq_group_meta[0].token_chunk_size == 1 + assert seq_group_meta[1].token_chunk_size == 1 + assert not running[0].is_prefill() + assert running[1].is_prefill() + assert running[2].is_prefill() + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 2 + append_new_token(running[0], 1) + + # Decoding + running prefill is prioritized. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 2 + assert seq_group_meta[0].token_chunk_size == 1 + assert seq_group_meta[1].token_chunk_size == 1 + assert not running[0].is_prefill() + assert not running[1].is_prefill() + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 2 + append_new_token(running[0], 1) + append_new_token(running[1], 1) + + # Only decoding is prioritized. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 2 + assert seq_group_meta[0].token_chunk_size == 1 + assert seq_group_meta[1].token_chunk_size == 1 + assert not running[0].is_prefill() + assert not running[1].is_prefill() + assert out.num_prefill_groups == 0 + assert out.num_batched_tokens == 2 + append_new_token(running[0], 1) + append_new_token(running[1], 1) + + # After aborting the decoding request, the fcfs new prefill is prioritized. + scheduler.abort_seq_group(running[0].request_id) + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 2 + assert seq_group_meta[0].token_chunk_size == 1 + assert seq_group_meta[1].token_chunk_size == 1 + assert not running[1].is_prefill() + assert running[2].is_prefill() + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 2 + + +def test_prompt_limit(): + """Verify max_num_batched_tokens < max_model_len is possible.""" + block_size = 4 + max_seqs = 32 + max_model_len = 64 + max_num_batched_tokens = 32 + scheduler_config = SchedulerConfig(max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + _, seq_group = create_dummy_prompt("1", prompt_length=48) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + assert seq_group.is_prefill() + + # The prompt length > max_num_batched_tokens should be still scheduled. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 1 + assert seq_group_meta[0].token_chunk_size == 32 + assert running[0].is_prefill() + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 32 + + +def test_prompt_limit_exceed(): + block_size = 4 + max_seqs = 64 + max_model_len = 32 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig(max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + _, seq_group = create_dummy_prompt("2", prompt_length=48) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + assert seq_group.is_prefill() + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.ignored_seq_groups) == 1 + assert out.ignored_seq_groups[0] == seq_group + + +def test_swap(): + """Verify swapping works with chunked prefill requests""" + block_size = 4 + max_seqs = 30 + max_model_len = 200 + max_num_batched_tokens = 30 + scheduler_config = SchedulerConfig(max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + + _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + scheduler.add_seq_group(seq_group) + _, out = schedule_and_update_computed_tokens(scheduler) + # The request is chunked. + # prefill scheduled now. + assert len(out.scheduled_seq_groups) == 1 + assert out.num_prefill_groups == 1 + assert seq_group.is_prefill() + assert out.num_batched_tokens == max_num_batched_tokens + + # The last request should be swapped out. + scheduler.block_manager.can_append_slots = MagicMock() + + def cannot_append_second_group(seq_group, num_lookahead_slots): + return seq_group.request_id != "1" + + scheduler.block_manager.can_append_slots.side_effect = ( + cannot_append_second_group) + + # The running prefill is now swapped. + _, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 0 + assert out.num_batched_tokens == 0 + assert out.blocks_to_swap_out != {} + assert out.blocks_to_swap_in == {} + + # Add 1 more task. Swap should be prioritized over new prefill. + _, seq_group = create_dummy_prompt("2", prompt_length=60) + scheduler.add_seq_group(seq_group) + _, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 1 + # 3 decodes. It is swapped in. + assert out.num_batched_tokens == 30 + assert out.blocks_to_swap_in != {} + assert out.blocks_to_swap_out == {} + + +def test_running_prefill_prioritized_over_swap(): + block_size = 4 + max_seqs = 30 + max_model_len = 200 + max_num_batched_tokens = 30 + scheduler_config = SchedulerConfig(max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + + _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + scheduler.add_seq_group(seq_group) + _, out = schedule_and_update_computed_tokens(scheduler) + # The request is chunked. + # prefill scheduled now. + assert len(out.scheduled_seq_groups) == 1 + assert out.num_prefill_groups == 1 + assert seq_group.is_prefill() + assert out.num_batched_tokens == max_num_batched_tokens + + # The request should be swapped out. + scheduler.block_manager.can_append_slots = MagicMock() + + def cannot_append_second_group(seq_group, num_lookahead_slots): + return seq_group.request_id != "1" + + scheduler.block_manager.can_append_slots.side_effect = ( + cannot_append_second_group) + + # The running prefill is now swapped. + _, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 0 + assert out.num_batched_tokens == 0 + assert out.blocks_to_swap_out != {} + assert out.blocks_to_swap_in == {} + + # Add 1 more task. Swap is not possible, so prefill is running. + scheduler.block_manager.can_swap_in = MagicMock() + scheduler.block_manager.can_swap_in.return_value = False + + _, seq_group2 = create_dummy_prompt("2", prompt_length=60) + scheduler.add_seq_group(seq_group2) + _, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 1 + # 3 decodes. It is swapped in. + assert out.num_batched_tokens == 30 + assert out.blocks_to_swap_in == {} + assert out.blocks_to_swap_out == {} + assert out.scheduled_seq_groups[0].seq_group == seq_group2 + + # Now although swap is possible, running prefill is prioritized. + scheduler.block_manager.can_swap_in.return_value = True + _, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 1 + # 3 decodes. It is swapped in. + assert out.num_batched_tokens == 30 + assert out.blocks_to_swap_in == {} + assert out.blocks_to_swap_out == {} + assert not seq_group2.is_prefill() + assert out.scheduled_seq_groups[0].seq_group == seq_group2 + append_new_token(seq_group2, 1) + + # Decoding is prioritized. + _, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 1 + # 3 decodes. It is swapped in. + assert out.num_batched_tokens == 1 + assert out.blocks_to_swap_in == {} + assert out.blocks_to_swap_out == {} + assert not seq_group2.is_prefill() + assert out.scheduled_seq_groups[0].seq_group == seq_group2 + append_new_token(seq_group2, 1) + + # Since we abort the sequence group, we can finally swap. + scheduler.abort_seq_group(seq_group2.request_id) + _, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 1 + assert out.num_batched_tokens == 30 + assert out.blocks_to_swap_in != {} + assert out.blocks_to_swap_out == {} + + +def test_chunked_prefill_preempt(): + """Verify preempt works with chunked prefill requests""" + block_size = 4 + max_seqs = 30 + max_model_len = 200 + max_num_batched_tokens = 30 + scheduler_config = SchedulerConfig(max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + + _, seq_group = create_dummy_prompt("1", prompt_length=60) + scheduler.add_seq_group(seq_group) + _, out = schedule_and_update_computed_tokens(scheduler) + # The request is chunked. + # prefill scheduled now. + assert len(out.scheduled_seq_groups) == 1 + assert out.num_prefill_groups == 1 + assert seq_group.is_prefill() + assert out.num_batched_tokens == max_num_batched_tokens + + # The request should be preempted. + scheduler.block_manager.can_append_slots = MagicMock() + + def cannot_append_second_group(seq_group, num_lookahead_slots): + return seq_group.request_id != "1" + + scheduler.block_manager.can_append_slots.side_effect = ( + cannot_append_second_group) + + # The running prefill is now preempted. + _, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 0 + assert out.num_batched_tokens == 0 + assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in == {} + + # Make sure we can reschedule preempted request. + _, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 1 + assert out.num_prefill_groups == 1 + assert seq_group.is_prefill() + assert out.num_batched_tokens == max_num_batched_tokens + assert seq_group.get_num_uncomputed_tokens() == 30 + + # We should be able to run prefill twice as it is chunked. + def cannot_append_second_group(seq_group, num_lookahead_slots): + return True + + scheduler.block_manager.can_append_slots.side_effect = ( + cannot_append_second_group) + _, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 1 + assert out.num_prefill_groups == 1 + assert not seq_group.is_prefill() + assert out.num_batched_tokens == max_num_batched_tokens + + +def test_chunked_prefill_max_seqs(): + block_size = 4 + max_seqs = 2 + max_model_len = 80 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig(max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + running = [] + + _, seq_group = create_dummy_prompt("1", prompt_length=65) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + # The first prefill is chunked. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert seq_group_meta[0].token_chunk_size == max_num_batched_tokens + assert len(get_sequence_groups(out)) == 1 + + # Add new requests. + for i in range(4): + _, seq_group = create_dummy_prompt(str(i), prompt_length=65) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Make sure only 2 requests are scheduled. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert out.num_batched_tokens == max_num_batched_tokens + assert len(get_sequence_groups(out)) == 2 + assert not running[0].is_prefill() + assert running[1].is_prefill() + append_new_token(running[0], 1) + + # Although we have enough token budget, we can only schedule max_seqs. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert seq_group_meta[0].token_chunk_size == 2 + assert seq_group_meta[1].token_chunk_size == 1 + assert out.num_batched_tokens == 3 + assert len(get_sequence_groups(out)) == max_seqs + assert not running[0].is_prefill() + assert not running[1].is_prefill() diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 88c2c37f4fb39..9588a1bead5f6 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -1,11 +1,16 @@ import time +from collections import deque from typing import List +from unittest.mock import MagicMock import pytest # noqa -from vllm.config import CacheConfig, SchedulerConfig -from vllm.core.scheduler import Scheduler -from vllm.sequence import Logprob, SequenceGroup +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.core.interfaces import AllocStatus +from vllm.core.policy import PolicyFactory +from vllm.core.scheduler import Scheduler, SchedulingBudget +from vllm.lora.request import LoRARequest +from vllm.sequence import Logprob, SequenceGroup, SequenceStatus from .utils import create_dummy_prompt @@ -14,6 +19,26 @@ def get_sequence_groups(scheduler_output): return [s.seq_group for s in scheduler_output.scheduled_seq_groups] +def append_new_token(out, token_id: int): + seq_groups = get_sequence_groups(out) + for seq_group in seq_groups: + for seq in seq_group.get_seqs(): + seq.append_token_id(token_id, {token_id: Logprob(token_id)}) + + +def schedule_and_update_computed_tokens(scheduler): + metas, out = scheduler.schedule() + for s, meta in zip(out.scheduled_seq_groups, metas): + s.seq_group.update_num_computed_tokens(meta.token_chunk_size) + return metas, out + + +def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int): + seq_group.update_num_computed_tokens(token_chunk_size) + for seq in seq_group.get_seqs(): + seq.append_token_id(token_id, {token_id: Logprob(token_id)}) + + def test_scheduler_add_seq_group(): block_size = 4 scheduler_config = SchedulerConfig(100, 64, 1) @@ -71,20 +96,52 @@ def test_scheduler_schedule_simple(): # Schedule seq groups prompts. num_tokens = block_size * num_seq_group - seq_group_meta, out = scheduler.schedule() + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert set(get_sequence_groups(out)) == set(running) assert out.num_batched_tokens == num_tokens assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) assert len(seq_group_meta) == num_seq_group + append_new_token(out, 1) # Schedule seq groups generation. - seq_group_meta, out = scheduler.schedule() + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert set(get_sequence_groups(out)) == set(running) assert out.num_batched_tokens == num_seq_group assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) assert len(seq_group_meta) == num_seq_group + append_new_token(out, 1) + + +def test_scheduler_prefill_prioritized(): + """Verify running batched tokens are not applied to prefill requests.""" + block_size = 4 + max_model_len = 30 + max_batched_num_tokens = 30 + scheduler_config = SchedulerConfig(max_batched_num_tokens, 2, + max_model_len) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 2 + cache_config.num_gpu_blocks = 2 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add seq groups to scheduler. + _, seq_group_a = create_dummy_prompt("1", 1) + scheduler.add_seq_group(seq_group_a) + + # Schedule seq groups prompts. + _, out = schedule_and_update_computed_tokens(scheduler) + assert get_sequence_groups(out) == [seq_group_a] + + # Add a new prefill request B. + _, seq_group_b = create_dummy_prompt("2", 30) + scheduler.add_seq_group(seq_group_b) + + # Verify prefill requests are prioritized. Since max_batched_num_tokens + # is 1, new prefill request has to be scheduled first. + _, out = schedule_and_update_computed_tokens(scheduler) + assert get_sequence_groups(out) == [seq_group_b] def test_scheduler_schedule_preempt_abort(): @@ -103,7 +160,7 @@ def test_scheduler_schedule_preempt_abort(): scheduler.add_seq_group(seq_group_b) # Schedule seq groups prompts. - seq_group_meta, out = scheduler.schedule() + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert get_sequence_groups(out) == [seq_group_a, seq_group_b] assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b assert (not out.blocks_to_copy and not out.blocks_to_swap_in @@ -113,12 +170,10 @@ def test_scheduler_schedule_preempt_abort(): # Append "generated" tokens, allowing the sequence to mark prompt tokens as # processed. - token_id = 0 - seq_a.append_token_id(token_id, {token_id: Logprob(0.0)}) - seq_b.append_token_id(token_id, {token_id: Logprob(0.0)}) + append_new_token(out, 1) # Schedule seq groups generation and preempt seq group b. - seq_group_meta, out = scheduler.schedule() + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert get_sequence_groups(out) == [seq_group_a] assert out.num_batched_tokens == 1 assert (not out.blocks_to_copy and not out.blocks_to_swap_in @@ -128,7 +183,7 @@ def test_scheduler_schedule_preempt_abort(): # Abort seq group a. Re-schedule seq group b prompt with recomputation. scheduler.abort_seq_group("1") - seq_group_meta, out = scheduler.schedule() + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert get_sequence_groups(out) == [seq_group_b] assert out.num_batched_tokens == 5 # 4 prompt + 1 generation. assert (not out.blocks_to_copy and not out.blocks_to_swap_in @@ -158,12 +213,14 @@ def test_scheduler_max_seqs(): scheduler.add_seq_group(all_seq_groups[0]) # Schedule seq groups prompts. - _, out = scheduler.schedule() + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert set(get_sequence_groups(out)) == set([all_seq_groups[0]]) + append_new_token(out, 1) # Schedule seq groups generation. - _, out = scheduler.schedule() + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert set(get_sequence_groups(out)) == set([all_seq_groups[0]]) + append_new_token(out, 1) # Append 2 more seq group scheduler.add_seq_group(all_seq_groups[1]) @@ -172,12 +229,11 @@ def test_scheduler_max_seqs(): # Schedule seq groups prompts. # Only 1 seq group should be scheduled since max_seq_group is 2 # and one is prompting. - _, out = scheduler.schedule() + _, out = schedule_and_update_computed_tokens(scheduler) assert set(get_sequence_groups(out)) == set([all_seq_groups[1]]) def test_scheduler_delay_factor(): - block_size = 4 scheduler_config = SchedulerConfig(100, 64, 16, delay_factor=0.5) cache_config = CacheConfig(block_size, 1.0, 1, "auto") @@ -186,24 +242,630 @@ def test_scheduler_delay_factor(): scheduler = Scheduler(scheduler_config, cache_config, None) # schedule first prompt - _, seq_group = create_dummy_prompt("0", prompt_length=block_size) + seq_group_meta, seq_group = create_dummy_prompt("0", + prompt_length=block_size) scheduler.add_seq_group(seq_group) - seq_group_meta, out = scheduler.schedule() - assert out.prompt_run + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert out.num_prefill_groups > 0 assert seq_group_meta[0].request_id == '0' + append_new_token(out, 1) # wait for a second before scheduling next prompt time.sleep(1) - _, seq_group = create_dummy_prompt("1", prompt_length=block_size) + seq_group_meta, seq_group = create_dummy_prompt("1", + prompt_length=block_size) scheduler.add_seq_group(seq_group) # second prompt should *not* be scheduled - seq_group_meta, out = scheduler.schedule() - assert not out.prompt_run + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert out.num_prefill_groups == 0 assert seq_group_meta[0].request_id == '0' + append_new_token(out, 1) # wait for more than 0.5 second and try again time.sleep(0.6) - seq_group_meta, out = scheduler.schedule() - assert out.prompt_run + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert out.num_prefill_groups > 0 assert seq_group_meta[0].request_id == '1' + append_new_token(out, 1) + + +def test_swapped_out_prioritized(): + scheduler = initialize_scheduler(max_num_seqs=6) + # best_of=2 * 3 == 6 sequences. + for i in range(3): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) + scheduler.add_seq_group(seq_group) + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + # prefill scheduled now. + assert len(out.scheduled_seq_groups) == 3 + append_new_token(out, 1) + + # The last request should be swapped out. + scheduler.block_manager.can_append_slots = MagicMock() + + def cannot_append_second_group(seq_group, num_lookahead_slots): + return seq_group.request_id != "2" + + scheduler.block_manager.can_append_slots.side_effect = ( + cannot_append_second_group) + + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 2 + assert out.num_batched_tokens == 2 + assert out.blocks_to_swap_out != {} + assert out.blocks_to_swap_in == {} + append_new_token(out, 1) + + # Add 1 more task. Swap should be prioritized over prefill. + _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) + scheduler.add_seq_group(seq_group) + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + append_new_token(out, 1) + assert len(out.scheduled_seq_groups) == 3 + # 3 decodes. It is swapped in. + assert out.num_batched_tokens == 3 + assert out.blocks_to_swap_in != {} + assert out.blocks_to_swap_out == {} + + +def initialize_scheduler(*, + max_num_seqs=1000, + max_token_budget=1000, + max_model_len=1000, + lora_config=None): + block_size = 4 + scheduler_config = SchedulerConfig(max_token_budget, max_num_seqs, + max_model_len) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, lora_config) + return scheduler + + +def create_token_budget(token_budget: int = 10000, + max_num_seqs: int = 10000) -> SchedulingBudget: + return SchedulingBudget( + token_budget=token_budget, + max_num_seqs=max_num_seqs, + ) + + +def add_token_budget(budget: SchedulingBudget, + num_batched_tokens: int = 0, + num_curr_seqs: int = 0): + mock_seq_group = create_dummy_prompt('10', prompt_length=60)[1] + budget.add_num_batched_tokens(mock_seq_group.request_id, + num_batched_tokens) + budget.add_num_seqs(mock_seq_group.request_id, num_curr_seqs) + + +def test_prefill_schedule_max_prompt_len(): + """ + Test prompt longer than max_prompt_len is aborted. + """ + scheduler = initialize_scheduler(max_model_len=30) + _, seq_group = create_dummy_prompt(0, prompt_length=60) + waiting = deque([seq_group]) + budget = create_token_budget() + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.ignored_seq_groups) == 1 + assert len(output.seq_groups) == 0 + assert budget.num_batched_tokens == 0 + assert budget.num_curr_seqs == 0 + assert len(remaining_waiting) == 0 + + +def test_prefill_schedule_token_budget(): + """ + Test token budget respected. + """ + scheduler = initialize_scheduler() + waiting = deque() + budget = create_token_budget(token_budget=0) + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + waiting.append(seq_group) + + # 0 token budget == nothing is scheduled. + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.ignored_seq_groups) == 0 + assert len(output.seq_groups) == 0 + assert budget.num_batched_tokens == 0 + assert budget.num_curr_seqs == 0 + assert len(remaining_waiting) == 2 + + # 60 token budget == 1 request scheduled. + budget = create_token_budget(token_budget=60) + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.ignored_seq_groups) == 0 + assert len(output.seq_groups) == 1 + assert budget.num_batched_tokens == 60 + assert budget.num_curr_seqs == 1 + assert len(remaining_waiting) == 1 + + # Test when current_batched_tokens respected. + scheduler = initialize_scheduler() + waiting = deque() + budget = create_token_budget(token_budget=60) + add_token_budget(budget, 30, 0) + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + # Cannot schedule a prompt that doesn't fit the budget. + waiting.append(seq_group) + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.ignored_seq_groups) == 0 + assert len(output.seq_groups) == 0 + assert budget.num_batched_tokens == 30 + assert budget.num_curr_seqs == 0 + assert len(remaining_waiting) == 1 + budget = create_token_budget(token_budget=90) + add_token_budget(budget, 30, 0) + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.seq_groups) == 1 + assert budget.num_batched_tokens == 90 + assert budget.num_curr_seqs == 1 + assert len(remaining_waiting) == 0 + + +def test_prefill_schedule_max_seqs(): + """ + Test max seq respected. + """ + scheduler = initialize_scheduler() + waiting = deque() + budget = create_token_budget(max_num_seqs=2) + for i in range(3): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + waiting.append(seq_group) + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.ignored_seq_groups) == 0 + assert len(output.seq_groups) == 2 + assert budget.num_batched_tokens == 120 + assert budget.num_curr_seqs == 2 + assert len(remaining_waiting) == 1 + + # Verify curr_num_seqs respected. + waiting = deque() + budget = create_token_budget(max_num_seqs=2) + add_token_budget(budget, 0, 2) + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + waiting.append(seq_group) + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.ignored_seq_groups) == 0 + assert len(output.seq_groups) == 0 + assert budget.num_batched_tokens == 0 + assert budget.num_curr_seqs == 2 + assert len(remaining_waiting) == 1 + + +def test_prefill_schedule_max_lora(): + """ + Test max lora is respected and prioritized. + """ + lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) + scheduler = initialize_scheduler(lora_config=lora_config) + waiting = deque() + budget = create_token_budget(token_budget=120) + curr_loras = set() + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + lora_request=LoRARequest( + lora_name=str(i), + lora_int_id=i + 1, + lora_local_path="abc")) + waiting.append(seq_group) + # Add two more requests to verify lora is prioritized. + # 0: Lora, 1: Lora, 2: regular, 3: regular + # In the first iteration, index 0, 2 is scheduled. + # If a request is not scheduled because it hits max lora, it is + # prioritized. Verify that. + for i in range(2, 4): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + waiting.append(seq_group) + # Schedule 2 requests (0 and 2) + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, curr_loras) + assert len(output.ignored_seq_groups) == 0 + assert len(output.seq_groups) == 2 + assert budget.num_batched_tokens == 120 + assert budget.num_curr_seqs == 2 + assert len(remaining_waiting) == 2 + assert len(curr_loras) == 1 + # The second lora request is scheduled next as FCFS policy. + # Reset curr_loras so that it can be scheduled. + curr_loras = set() + budget = create_token_budget(token_budget=60) + remaining_waiting, output = scheduler._schedule_prefills( + remaining_waiting, budget, curr_loras) + assert len(output.seq_groups) == 1 + assert output.seq_groups[0].seq_group.request_id == "1" + assert len(remaining_waiting) == 1 + assert len(curr_loras) == 1 + assert budget.num_batched_tokens == 60 + + +def test_prefill_schedule_no_block_manager_capacity(): + """ + Test sequence cannot be scheduled due to block manager has no capacity. + """ + scheduler = initialize_scheduler() + waiting = deque() + budget = create_token_budget() + for i in range(3): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + waiting.append(seq_group) + scheduler.block_manager.can_allocate = MagicMock() + scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER + remainig_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.ignored_seq_groups) == 0 + assert len(output.seq_groups) == 0 + assert budget.num_batched_tokens == 0 + assert budget.num_curr_seqs == 0 + assert len(remainig_waiting) == 3 + + scheduler = initialize_scheduler() + waiting = deque() + budget = create_token_budget() + for i in range(3): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + waiting.append(seq_group) + scheduler.block_manager.can_allocate = MagicMock() + scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.ignored_seq_groups) == 3 + assert len(output.seq_groups) == 0 + assert budget.num_batched_tokens == 0 + assert budget.num_curr_seqs == 0 + assert len(remaining_waiting) == 0 + + +def test_decode_schedule_preempted(): + """ + Test decodes cannot be scheduled and preempted. + """ + scheduler = initialize_scheduler() + running = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + for i in range(3): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + scheduler._allocate_and_set_running(seq_group, 60) + append_new_token_seq_group(60, seq_group, 1) + running.append(seq_group) + scheduler.block_manager.can_append_slots = MagicMock() + + def cannot_append_second_group(seq_group, num_lookahead_slots): + return seq_group.request_id != "1" + + scheduler.block_manager.can_append_slots.side_effect = ( + cannot_append_second_group) + + # 1 cannot be scheduled, and the lowest priority (request 2) + # should be preempted. 1 will also be preempted. + budget = create_token_budget() + remainig_running, output = scheduler._schedule_running( + running, budget, curr_loras, policy) + assert len(remainig_running) == 0 + assert len(output.decode_seq_groups) == 1 + assert len(output.prefill_seq_groups) == 0 + assert output.decode_seq_groups[0].seq_group.request_id == "0" + assert len(output.preempted) == 2 + # Verify budgets are updated. + assert budget.num_batched_tokens == 1 + assert budget.num_curr_seqs == 1 + # Both should be preempted, not swapped. + assert output.blocks_to_swap_out == {} + # Nothing is copied. + assert output.blocks_to_copy == {} + + +def test_decode_swap_beam_search(): + """ + Test best_of > 1 swap out blocks + """ + scheduler = initialize_scheduler() + running = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + budget = create_token_budget() + for i in range(3): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) + scheduler._allocate_and_set_running(seq_group, 60) + running.append(seq_group) + append_new_token_seq_group(60, seq_group, 1) + budget.add_num_seqs(seq_group.request_id, + seq_group.get_max_num_running_seqs()) + budget.add_num_batched_tokens( + seq_group.request_id, seq_group.num_seqs(SequenceStatus.RUNNING)) + + # The last request should be swapped out. + scheduler.block_manager.can_append_slots = MagicMock() + + def cannot_append_second_group(seq_group, num_lookahead_slots): + return seq_group.request_id != "2" + + scheduler.block_manager.can_append_slots.side_effect = ( + cannot_append_second_group) + scheduler.block_manager.swap_out = MagicMock() + expected_swap_mapping = {"5": "7"} + scheduler.block_manager.swap_out.return_value = expected_swap_mapping + + remainig_running, output = scheduler._schedule_running( + running, budget, curr_loras, policy) + assert len(remainig_running) == 0 + assert len(output.decode_seq_groups) == 2 + assert len(output.prefill_seq_groups) == 0 + assert output.decode_seq_groups[0].seq_group.request_id == "0" + assert output.decode_seq_groups[1].seq_group.request_id == "1" + assert len(output.preempted) == 0 + assert len(output.swapped_out) == 1 + # Budget should refledct preempted requests. + assert budget.num_batched_tokens == 2 + # since there are 2 sequences, 2 should be subtracted. + assert budget.num_curr_seqs == 4 + # Both should be preempted, not swapped. + assert output.blocks_to_swap_out == expected_swap_mapping + # Nothing is copied. + assert output.blocks_to_copy == {} + + +def test_schedule_decode_blocks_to_copy_update(): + """ + Verify blocks_to_copy is updated. + """ + scheduler = initialize_scheduler() + _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + running = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + scheduler._allocate_and_set_running(seq_group, 60) + append_new_token_seq_group(60, seq_group, 1) + running.append(seq_group) + + # The last request should be swapped out. + scheduler.block_manager.append_slots = MagicMock() + scheduler.block_manager.append_slots.return_value = {2: [3]} + + budget = create_token_budget() + remaining_running, output = scheduler._schedule_running( + running, budget, curr_loras, policy) + assert len(remaining_running) == 0 + assert len(output.decode_seq_groups) == 1 + assert len(output.prefill_seq_groups) == 0 + assert len(output.preempted) == 0 + assert len(output.swapped_out) == 0 + # Nothing is preempted. + assert output.blocks_to_swap_out == {} + # Since append_slot returns the source -> dist mapping, it should + # applied. + assert output.blocks_to_copy == {2: [3]} + + +def test_schedule_swapped_simple(): + scheduler = initialize_scheduler() + swapped = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + blocks_to_swap_out = {} + _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + scheduler._allocate_and_set_running(seq_group, 60) + append_new_token_seq_group(60, seq_group, 1) + scheduler._swap_out(seq_group, blocks_to_swap_out) + swapped.append(seq_group) + + budget = create_token_budget() + remaining_swapped, output = scheduler._schedule_swapped( + swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 0 + assert budget.num_batched_tokens == 1 + assert budget.num_curr_seqs == 2 + assert len(output.decode_seq_groups) == 1 + assert len(output.prefill_seq_groups) == 0 + # swap in is the reverse of swap out + blocks_to_swap_in_reverse = {} + for swapin, swapout in output.blocks_to_swap_in.items(): + blocks_to_swap_in_reverse[swapout] = swapin + assert blocks_to_swap_out == blocks_to_swap_in_reverse + + +def test_schedule_swapped_max_token_budget(): + scheduler = initialize_scheduler() + swapped = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + blocks_to_swap_out = {} + for _ in range(2): + _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + scheduler._allocate_and_set_running(seq_group, 60) + append_new_token_seq_group(60, seq_group, 1) + scheduler._swap_out(seq_group, blocks_to_swap_out) + swapped.append(seq_group) + + budget = create_token_budget(token_budget=1) + remaining_swapped, output = scheduler._schedule_swapped( + swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 1 + assert budget.num_batched_tokens == 1 + assert budget.num_curr_seqs == 2 + assert len(output.decode_seq_groups) == 1 + assert len(output.prefill_seq_groups) == 0 + + # Verify num_batched_tokens are respected. + budget = create_token_budget(token_budget=1) + add_token_budget(budget, 1, 0) + remaining_swapped, output = scheduler._schedule_swapped( + remaining_swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 1 + assert budget.num_batched_tokens == 1 + assert budget.num_curr_seqs == 0 + assert len(output.decode_seq_groups) == 0 + assert len(output.prefill_seq_groups) == 0 + + +def test_schedule_swapped_max_seqs(): + scheduler = initialize_scheduler() + swapped = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + blocks_to_swap_out = {} + for i in range(4): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + scheduler._allocate_and_set_running(seq_group, 60) + append_new_token_seq_group(60, seq_group, 1) + scheduler._swap_out(seq_group, blocks_to_swap_out) + swapped.append(seq_group) + + budget = create_token_budget(max_num_seqs=2) + remaining_swapped, output = scheduler._schedule_swapped( + swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 2 + assert budget.num_batched_tokens == 2 + assert budget.num_curr_seqs == 2 + assert len(output.decode_seq_groups) == 2 + assert len(output.prefill_seq_groups) == 0 + + # Verify num_curr_seqs are respected. + remaining_swapped, output = scheduler._schedule_swapped( + remaining_swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 2 + assert budget.num_batched_tokens == 2 + assert budget.num_curr_seqs == 2 + assert len(output.decode_seq_groups) == 0 + assert len(output.prefill_seq_groups) == 0 + + +def test_schedule_swapped_max_loras(): + lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) + scheduler = initialize_scheduler(lora_config=lora_config) + swapped = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = set() + blocks_to_swap_out = {} + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + lora_request=LoRARequest( + lora_name=str(i), + lora_int_id=i + 1, + lora_local_path="abc")) + scheduler._allocate_and_set_running(seq_group, 60) + append_new_token_seq_group(60, seq_group, 1) + scheduler._swap_out(seq_group, blocks_to_swap_out) + swapped.append(seq_group) + + budget = create_token_budget() + remaining_swapped, output = scheduler._schedule_swapped( + swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 1 + assert budget.num_batched_tokens == 1 + assert budget.num_curr_seqs == 1 + assert len(output.decode_seq_groups) == 1 + assert len(output.prefill_seq_groups) == 0 + assert len(curr_loras) == 1 + + +def test_schedule_swapped_cannot_swap_in(): + scheduler = initialize_scheduler() + swapped = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + blocks_to_swap_out = {} + for _ in range(2): + _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + scheduler._allocate_and_set_running(seq_group, 60) + append_new_token_seq_group(60, seq_group, 1) + scheduler._swap_out(seq_group, blocks_to_swap_out) + swapped.append(seq_group) + + # The last request should be swapped out. + scheduler.block_manager.can_swap_in = MagicMock() + scheduler.block_manager.can_swap_in.return_value = False + # Since we cannot swap in, none of the requests are swapped in. + budget = create_token_budget() + remaining_swapped, output = scheduler._schedule_swapped( + swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 2 + assert budget.num_batched_tokens == 0 + assert budget.num_curr_seqs == 0 + assert len(output.decode_seq_groups) == 0 + assert len(output.prefill_seq_groups) == 0 + + +def test_schedule_swapped_blocks_to_copy(): + scheduler = initialize_scheduler() + swapped = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + scheduler._allocate_and_set_running(seq_group, 60) + append_new_token_seq_group(60, seq_group, 1) + blocks_to_swap_out = {} + scheduler._swap_out(seq_group, blocks_to_swap_out) + swapped.append(seq_group) + + # The last request should be swapped out. + scheduler.block_manager.append_slots = MagicMock() + scheduler.block_manager.append_slots.return_value = {2: [3]} + + budget = create_token_budget() + remaining_swapped, output = scheduler._schedule_swapped( + swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 0 + assert len(output.decode_seq_groups) == 1 + assert len(output.prefill_seq_groups) == 0 + assert output.blocks_to_copy == {2: [3]} + + +def test_scheduling_budget(): + TOKEN_BUDGET = 4 + MAX_SEQS = 4 + budget = SchedulingBudget(token_budget=TOKEN_BUDGET, max_num_seqs=MAX_SEQS) + assert budget.can_schedule(num_new_tokens=1, num_new_seqs=1) + assert budget.can_schedule(num_new_tokens=4, num_new_seqs=4) + assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=5) + assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=1) + assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=5) + assert budget.remaining_token_budget() == TOKEN_BUDGET + + # Verify add/subtract num batched tokens. + _, seq_group = create_dummy_prompt("1", 3) + budget.add_num_batched_tokens(seq_group.request_id, 2) + assert budget.remaining_token_budget() == 2 + assert budget.num_batched_tokens == 2 + assert budget.can_schedule(num_new_tokens=2, num_new_seqs=1) + assert not budget.can_schedule(num_new_tokens=3, num_new_seqs=1) + # Verify adding another seq group is no-op. + budget.add_num_batched_tokens(seq_group.request_id, 2) + assert budget.remaining_token_budget() == 2 + assert budget.num_batched_tokens == 2 + budget.subtract_num_batched_tokens(seq_group.request_id, 2) + assert budget.remaining_token_budget() == 4 + assert budget.num_batched_tokens == 0 + budget.subtract_num_batched_tokens(seq_group.request_id, 2) + assert budget.remaining_token_budget() == 4 + assert budget.num_batched_tokens == 0 + + # Verify add/subtract max seqs. + _, seq_group = create_dummy_prompt("1", 3) + budget.add_num_seqs(seq_group.request_id, 2) + assert budget.can_schedule(num_new_tokens=1, num_new_seqs=2) + assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=3) + assert budget.num_curr_seqs == 2 + # Verify adding another seq group is no-op. + budget.add_num_seqs(seq_group.request_id, 2) + assert budget.num_curr_seqs == 2 + budget.subtract_num_seqs(seq_group.request_id, 2) + assert budget.num_curr_seqs == 0 + budget.subtract_num_seqs(seq_group.request_id, 2) + assert budget.num_curr_seqs == 0 diff --git a/tests/core/utils.py b/tests/core/utils.py index 2e462f2aec4d4..fbbdb07cb8e6e 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -1,14 +1,19 @@ import time -from typing import Tuple +from typing import Optional, Tuple from vllm import SamplingParams +from vllm.lora.request import LoRARequest from vllm.sequence import Logprob, Sequence, SequenceGroup def create_dummy_prompt( - request_id: str, - prompt_length: int, - block_size: int = None) -> Tuple[Sequence, SequenceGroup]: + request_id: str, + prompt_length: int, + block_size: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + use_beam_search: bool = False, + best_of: int = 1, +) -> Tuple[Sequence, SequenceGroup]: if not block_size: block_size = prompt_length @@ -17,14 +22,16 @@ def create_dummy_prompt( prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) - seq_group = SequenceGroup(request_id, [prompt], SamplingParams(), - time.time(), None) + seq_group = SequenceGroup( + request_id, [prompt], + SamplingParams(use_beam_search=use_beam_search, best_of=best_of), + time.time(), lora_request) return prompt, seq_group def create_seq_group( - seq_prompt_lens=1024, + seq_prompt_len=1024, seq_output_lens=(128, ), request_id='0', seq_id_start=0, @@ -32,7 +39,7 @@ def create_seq_group( assert len(seq_output_lens) > 0 - prompt_token_ids = [0] * seq_prompt_lens + prompt_token_ids = [0] * seq_prompt_len seqs = [] for seq_id_offset, output_len in enumerate(seq_output_lens): diff --git a/tests/engine/test_detokenization.py b/tests/engine/test_detokenization.py new file mode 100644 index 0000000000000..f77f6d0725b6b --- /dev/null +++ b/tests/engine/test_detokenization.py @@ -0,0 +1,32 @@ +import pytest + +from vllm.entrypoints.llm import LLM +from vllm.sampling_params import SamplingParams + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +def test_computed_prefix_blocks(model: str): + # This test checks if the engine generates completions both with and + # without optional detokenization, that detokenization includes text + # and no-detokenization doesn't, and that both completions have the same + # token_ids. + prompt = ( + "You are a helpful assistant. How do I build a car from cardboard and " + "paper clips? Is there an easy to follow video tutorial available " + "online for free?") + + llm = LLM(model=model) + sampling_params = SamplingParams(max_tokens=10, + temperature=0.0, + detokenize=False) + + outputs_no_detokenization = llm.generate(prompt, + sampling_params)[0].outputs[0] + sampling_params.detokenize = True + outputs_with_detokenization = llm.generate(prompt, + sampling_params)[0].outputs[0] + + assert outputs_no_detokenization.text == '' + assert outputs_with_detokenization.text != '' + assert outputs_no_detokenization.token_ids == \ + outputs_with_detokenization.token_ids diff --git a/tests/entrypoints/test_server_oot_registration.py b/tests/entrypoints/test_server_oot_registration.py new file mode 100644 index 0000000000000..22e65bf7e7da1 --- /dev/null +++ b/tests/entrypoints/test_server_oot_registration.py @@ -0,0 +1,66 @@ +import multiprocessing +import sys +import time + +import torch +from openai import OpenAI, OpenAIError + +from vllm import ModelRegistry +from vllm.model_executor.models.opt import OPTForCausalLM +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.utils import get_open_port + + +class MyOPTForCausalLM(OPTForCausalLM): + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + # this dummy model always predicts the first token + logits = super().compute_logits(hidden_states, sampling_metadata) + logits.zero_() + logits[:, 0] += 1.0 + return logits + + +def server_function(port): + # register our dummy model + ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM) + sys.argv = ["placeholder.py"] + \ + ("--model facebook/opt-125m --dtype" + f" float32 --api-key token-abc123 --port {port}").split() + import runpy + runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') + + +def test_oot_registration_for_api_server(): + port = get_open_port() + server = multiprocessing.Process(target=server_function, args=(port, )) + server.start() + client = OpenAI( + base_url=f"http://localhost:{port}/v1", + api_key="token-abc123", + ) + while True: + try: + completion = client.chat.completions.create( + model="facebook/opt-125m", + messages=[{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": "user", + "content": "Hello!" + }], + temperature=0, + ) + break + except OpenAIError as e: + if "Connection error" in str(e): + time.sleep(3) + else: + raise e + server.kill() + generated_text = completion.choices[0].message.content + # make sure only the first token is generated + rest = generated_text.replace("", "") + assert rest == "" diff --git a/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json b/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json new file mode 100644 index 0000000000000..a548f0a9611f6 --- /dev/null +++ b/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json @@ -0,0 +1,90 @@ +{ + "model_type": "llama", + "kv_cache": { + "dtype": "float8_e4m3fn", + "scaling_factor": { + "0": { + "0": 0.0230364128947258, + "1": 0.01979283057153225, + "2": 0.0241350457072258, + "3": 0.0308314748108387, + "4": 0.0430733822286129, + "5": 0.0370396226644516, + "6": 0.0306222103536129, + "7": 0.0357491634786129, + "8": 0.0358189195394516, + "9": 0.0443289652466774, + "10": 0.0433175228536129, + "11": 0.0416782945394516, + "12": 0.0366908498108387, + "13": 0.0432477705180645, + "14": 0.0410505048930645, + "15": 0.0457589291036129, + "16": 0.0418526791036129, + "17": 0.0432477705180645, + "18": 0.0469447560608387, + "19": 0.0514787957072258, + "20": 0.0541294664144516, + "21": 0.0587681382894516, + "22": 0.0625, + "23": 0.0585588738322258, + "24": 0.0600237175822258, + "25": 0.0588030144572258, + "26": 0.0531180277466774, + "27": 0.06396484375, + "28": 0.0603027381002903, + "29": 0.0582101047039032, + "30": 0.0625348836183548, + "31": 0.0585588738322258, + "32": 0.0582798570394516, + "33": 0.0575125589966774, + "34": 0.0590820349752903, + "35": 0.0614188089966774, + "36": 0.0631975457072258, + "37": 0.0615931935608387, + "38": 0.0601283498108387, + "39": 0.0571986623108387, + "40": 0.0670340433716774, + "41": 0.0523507259786129, + "42": 0.0547223798930645, + "43": 0.0631975457072258, + "44": 0.0663713738322258, + "45": 0.0603376142680645, + "46": 0.0652204304933548, + "47": 0.0734514519572258, + "48": 0.0693708211183548, + "49": 0.0725446492433548, + "50": 0.0627790242433548, + "51": 0.0691266804933548, + "52": 0.0688825398683548, + "53": 0.068429134786129, + "54": 0.0605119988322258, + "55": 0.0799386203289032, + "56": 0.0853097140789032, + "57": 0.0661969929933548, + "58": 0.0689871683716774, + "59": 0.0724051371216774, + "60": 0.0541643425822258, + "61": 0.0626743882894516, + "62": 0.0628487765789032, + "63": 0.0607212632894516, + "64": 0.0589076466858387, + "65": 0.0451660193502903, + "66": 0.0453055277466774, + "67": 0.0414341539144516, + "68": 0.0385044664144516, + "69": 0.0414341539144516, + "70": 0.0466308631002903, + "71": 0.0399693101644516, + "72": 0.0437011756002903, + "73": 0.0434221550822258, + "74": 0.0428989976644516, + "75": 0.0401785746216774, + "76": 0.0431082621216774, + "77": 0.0484444759786129, + "78": 0.0417829267680645, + "79": 0.0418178029358387 + } + } + } +} \ No newline at end of file diff --git a/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json b/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json new file mode 100644 index 0000000000000..bb734039e982b --- /dev/null +++ b/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json @@ -0,0 +1,42 @@ +{ + "model_type": "llama", + "kv_cache": { + "dtype": "float8_e4m3fn", + "scaling_factor": { + "0": { + "0": 0.0152239128947258, + "1": 0.0188860222697258, + "2": 0.0354178324341774, + "3": 0.0376674123108387, + "4": 0.0418526791036129, + "5": 0.0433175228536129, + "6": 0.0397600457072258, + "7": 0.0424455925822258, + "8": 0.0415387861430645, + "9": 0.0408412404358387, + "10": 0.0395856611430645, + "11": 0.0377371683716774, + "12": 0.0400739423930645, + "13": 0.040771484375, + "14": 0.0393415205180645, + "15": 0.0369001142680645, + "16": 0.03857421875, + "17": 0.0387486070394516, + "18": 0.0403180830180645, + "19": 0.0396205373108387, + "20": 0.0375627800822258, + "21": 0.0407366082072258, + "22": 0.0432477705180645, + "23": 0.0377022884786129, + "24": 0.0399693101644516, + "25": 0.0374581478536129, + "26": 0.0413295216858387, + "27": 0.0442243330180645, + "28": 0.0424804724752903, + "29": 0.0456891767680645, + "30": 0.0409109964966774, + "31": 0.0482352152466774 + } + } + } +} diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index b03fecffdc645..03ea72924921e 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -32,7 +32,7 @@ HEAD_SIZES = [64, 80, 96, 112, 128, 256 BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] -KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] +KV_CACHE_DTYPE = ["auto", "fp8"] SEEDS = [0] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) @@ -172,6 +172,9 @@ def test_paged_attention( device) key_cache, value_cache = key_caches[0], value_caches[0] + # Using default kv_scale + kv_scale = 1.0 + # Call the paged attention kernel. output = torch.empty_like(query) if version == "v1": @@ -188,6 +191,7 @@ def test_paged_attention( max_context_len, alibi_slopes, kv_cache_dtype, + kv_scale, ) elif version == "v2": num_partitions = ((max_context_len + PARTITION_SIZE - 1) // @@ -219,12 +223,13 @@ def test_paged_attention( max_context_len, alibi_slopes, kv_cache_dtype, + kv_scale, ) else: raise AssertionError(f"Unknown version: {version}") # Run the reference implementation. - if kv_cache_dtype == "fp8_e5m2": + if kv_cache_dtype == "fp8": # Convert cache data back to dtype. x = 16 // torch.tensor([], dtype=dtype).element_size() key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, @@ -232,14 +237,14 @@ def test_paged_attention( dequantized_key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) - cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache) + cache_ops.convert_fp8(key_cache, dequantized_key_cache) key_cache = dequantized_key_cache value_cache_shape = value_cache.shape dequantized_value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) - cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache) + cache_ops.convert_fp8(value_cache, dequantized_value_cache) value_cache = dequantized_value_cache ref_output = torch.empty_like(query) @@ -263,7 +268,8 @@ def test_paged_attention( # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, # so we use a relaxed tolerance for the test. - if kv_cache_dtype == "fp8_e5m2": + atol, rtol = 1e-3, 1e-5 + if kv_cache_dtype == "fp8": atol, rtol = 1e-2, 1e-5 assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 0cdb92f2d9700..4141aacafd0b2 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -5,6 +5,7 @@ import pytest import torch from vllm._C import cache_ops +from vllm.utils import is_hip COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -23,7 +24,7 @@ SEEDS = [0] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] -KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] +KV_CACHE_DTYPE = ["auto", "fp8"] @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @@ -105,6 +106,7 @@ def test_copy_blocks( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @torch.inference_mode() def test_reshape_and_cache( kv_cache_factory, @@ -116,7 +118,10 @@ def test_reshape_and_cache( dtype: torch.dtype, seed: int, device: str, + kv_cache_dtype: str, ) -> None: + if not is_hip() and kv_cache_dtype == "fp8": + pytest.skip() # This test is not tuned for e5m2 cuda precision random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -132,17 +137,33 @@ def test_reshape_and_cache( # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1, - num_heads, head_size, dtype, - None, seed, device) + num_heads, head_size, + kv_cache_dtype, dtype, seed, + device) key_cache, value_cache = key_caches[0], value_caches[0] # Clone the KV caches. - cloned_key_cache = key_cache.clone() - cloned_value_cache = value_cache.clone() + if kv_cache_dtype == "fp8": + cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) + cache_ops.convert_fp8(key_cache, cloned_key_cache) + cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) + cache_ops.convert_fp8(value_cache, cloned_value_cache) + else: + cloned_key_cache = key_cache.clone() + cloned_value_cache = value_cache.clone() + + # Using default kv_scale + kv_scale = 1.0 # Call the reshape_and_cache kernel. cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping, "auto") + slot_mapping, kv_cache_dtype, kv_scale) + + if kv_cache_dtype == "fp8": + result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) + cache_ops.convert_fp8(key_cache, result_key_cache) + result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) + cache_ops.convert_fp8(value_cache, result_value_cache) # Run the reference implementation. reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) @@ -156,8 +177,18 @@ def test_reshape_and_cache( cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i] cloned_value_cache[block_idx, :, :, block_offset] = value[i] - assert torch.allclose(key_cache, cloned_key_cache) - assert torch.allclose(value_cache, cloned_value_cache) + if kv_cache_dtype == "fp8": + assert torch.allclose(result_key_cache, + cloned_key_cache, + atol=0.001, + rtol=0.1) + assert torch.allclose(result_value_cache, + cloned_value_cache, + atol=0.001, + rtol=0.1) + else: + assert torch.allclose(key_cache, cloned_key_cache) + assert torch.allclose(value_cache, cloned_value_cache) @pytest.mark.parametrize("direction", COPYING_DIRECTION) @@ -169,6 +200,7 @@ def test_reshape_and_cache( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @torch.inference_mode() def test_swap_blocks( kv_cache_factory, @@ -181,7 +213,12 @@ def test_swap_blocks( dtype: torch.dtype, seed: int, device: str, + kv_cache_dtype: str, ) -> None: + if kv_cache_dtype == "fp8" and "cpu" in direction: + pytest.skip() + if not is_hip() and kv_cache_dtype == "fp8": + pytest.skip() # This test is not tuned for e5m2 cuda precision random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -202,13 +239,13 @@ def test_swap_blocks( # Create the KV caches on the first device. src_key_caches, src_value_caches = kv_cache_factory( - num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed, - src_device) + num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype, + seed, src_device) # Create the KV caches on the second device. dist_key_caches, dist_value_caches = kv_cache_factory( - num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed, - dst_device) + num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype, + seed, dst_device) src_key_caches_clone = src_key_caches[0].clone() src_value_caches_clone = src_value_caches[0].clone() @@ -223,3 +260,40 @@ def test_swap_blocks( dist_key_caches[0][dst].cpu()) assert torch.allclose(src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()) + + +@pytest.mark.skipif(not is_hip(), reason="FP8 conversion test requires e4m3") +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_fp8_conversion( + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + low = -224.0 + high = 224.0 + shape = (num_blocks, num_heads, head_size, block_size) + cache = torch.empty(shape, dtype=dtype, device=device) + cache.uniform_(low, high) + + cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) + cache_ops.convert_fp8(cache, cache_fp8) + + converted_cache = torch.empty_like(cache) + cache_ops.convert_fp8(cache_fp8, converted_cache) + + assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1) diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py new file mode 100644 index 0000000000000..35ad7342944cd --- /dev/null +++ b/tests/lora/test_lora_checkpoints.py @@ -0,0 +1,40 @@ +import pytest + +from vllm.lora.models import LoRAModel +from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM + + +@pytest.mark.parametrize("lora_name", ["baichuan7B", "chatglm3-6b"]) +def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files): + supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules + packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping + embedding_modules = BaiChuanBaseForCausalLM.embedding_modules + embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules + expected_lora_modules = [] + for module in supported_lora_modules: + if module in packed_modules_mapping: + expected_lora_modules.extend(packed_modules_mapping[module]) + else: + expected_lora_modules.append(module) + if lora_name == "baichuan7B": + # For the baichuan7B model, load it's LoRA, + # and the test should pass. + LoRAModel.from_local_checkpoint( + baichuan_lora_files, + expected_lora_modules, + lora_model_id=1, + device="cpu", + embedding_modules=embedding_modules, + embedding_padding_modules=embed_padding_modules) + else: + # For the baichuan7B model, load chatglm3-6b's LoRA, + # and the test should raise the following error. + expected_error = "Please verify that the loaded LoRA module is correct" # noqa: E501 + with pytest.raises(ValueError, match=expected_error): + LoRAModel.from_local_checkpoint( + chatglm3_lora_files, + expected_lora_modules, + lora_model_id=1, + device="cpu", + embedding_modules=embedding_modules, + embedding_padding_modules=embed_padding_modules) diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 60aa90fe4ee8a..54594690f7922 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -3,8 +3,8 @@ import random import tempfile from unittest.mock import patch -from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest from vllm.worker.worker import Worker @@ -27,6 +27,10 @@ def test_worker_apply_lora(sql_lora_files): parallel_config=ParallelConfig(1, 1, False), scheduler_config=SchedulerConfig(32, 32, 32), device_config=DeviceConfig("cuda"), + cache_config=CacheConfig(block_size=16, + gpu_memory_utilization=1., + swap_space=0, + cache_dtype="auto"), local_rank=0, rank=0, lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32, diff --git a/tests/model_executor/weight_utils.py b/tests/model_executor/weight_utils.py new file mode 100644 index 0000000000000..3154f2826d10c --- /dev/null +++ b/tests/model_executor/weight_utils.py @@ -0,0 +1,26 @@ +import os + +import huggingface_hub.constants +import pytest + +from vllm.model_executor.weight_utils import enable_hf_transfer + + +def test_hf_transfer_auto_activation(): + if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ: + # in case it is already set, we can't test the auto activation + pytest.skip( + "HF_HUB_ENABLE_HF_TRANSFER is set, can't test auto activation") + enable_hf_transfer() + try: + # enable hf hub transfer if available + import hf_transfer # type: ignore # noqa + HF_TRANFER_ACTIVE = True + except ImportError: + HF_TRANFER_ACTIVE = False + assert (huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER == + HF_TRANFER_ACTIVE) + + +if __name__ == "__main__": + test_hf_transfer_auto_activation() diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py new file mode 100644 index 0000000000000..50ab06631500b --- /dev/null +++ b/tests/models/test_oot_registration.py @@ -0,0 +1,32 @@ +import torch + +from vllm import LLM, ModelRegistry, SamplingParams +from vllm.model_executor.models.opt import OPTForCausalLM +from vllm.model_executor.sampling_metadata import SamplingMetadata + + +class MyOPTForCausalLM(OPTForCausalLM): + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + # this dummy model always predicts the first token + logits = super().compute_logits(hidden_states, sampling_metadata) + logits.zero_() + logits[:, 0] += 1.0 + return logits + + +def test_oot_registration(): + # register our dummy model + ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM) + prompts = ["Hello, my name is", "The text does not matter"] + sampling_params = SamplingParams(temperature=0) + llm = LLM(model="facebook/opt-125m") + first_token = llm.get_tokenizer().decode(0) + outputs = llm.generate(prompts, sampling_params) + + for output in outputs: + generated_text = output.outputs[0].text + # make sure only the first token is generated + rest = generated_text.replace(first_token, "") + assert rest == "" diff --git a/tests/quantization/test_autogptq_marlin_configs.py b/tests/quantization/test_autogptq_marlin_configs.py new file mode 100644 index 0000000000000..cd64622e2226f --- /dev/null +++ b/tests/quantization/test_autogptq_marlin_configs.py @@ -0,0 +1,68 @@ +"""Tests whether Marlin models can be loaded from the autogptq config. + +Run `pytest tests/quantization/test_autogptq_marlin_configs.py --forked`. +""" + +from dataclasses import dataclass + +import pytest + +from vllm.config import ModelConfig + + +@dataclass +class ModelPair: + model_marlin: str + model_gptq: str + + +# Model Id // Expected Kernel +MODELS_QUANT_TYPE = [ + # compat: autogptq <=0.7.1 is_marlin_format: bool + ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin"), + ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq"), + # compat: autogptq >=0.8.0 use checkpoint_format: str + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq") +] + + +@pytest.mark.parametrize("model_quant_type", MODELS_QUANT_TYPE) +def test_auto_gptq(model_quant_type: str, ) -> None: + model_path, quant_type = model_quant_type + + model_config_no_quant_arg = ModelConfig( + model_path, + model_path, + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + quantization=None # case 1 + ) + + model_config_quant_arg = ModelConfig( + model_path, + model_path, + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + quantization="gptq" # case 2 + ) + + assert model_config_no_quant_arg.quantization == quant_type, ( + f"Expected quant_type == {quant_type} for {model_path}, " + f"but found {model_config_no_quant_arg.quantization} " + "for no --quantization None case") + + assert model_config_quant_arg.quantization == quant_type, ( + f"Expected quant_type == {quant_type} for {model_path}, " + f"but found {model_config_quant_arg.quantization} " + "for --quantization gptq case") diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py new file mode 100644 index 0000000000000..3788e9e9752ff --- /dev/null +++ b/tests/samplers/test_logits_processor.py @@ -0,0 +1,62 @@ +import pytest +import torch + +from vllm import SamplingParams + +MODELS = ["facebook/opt-125m"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_logits_processor_force_generate( + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + vllm_model = vllm_runner(model, dtype=dtype) + tokenizer = vllm_model.model.get_tokenizer() + repeat_times = 2 + enforced_answers = " vLLM" + vllm_token_ids = tokenizer.encode(enforced_answers, + add_special_tokens=False) + max_tokens = len(vllm_token_ids) * repeat_times + + def pick_vllm(token_ids, logits): + token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)] + logits[token_id] = torch.finfo(logits.dtype).max + return logits + + params_with_logprobs = SamplingParams( + logits_processors=[pick_vllm], + prompt_logprobs=3, + max_tokens=max_tokens, + ) + + # test logits_processors when prompt_logprobs is not None + vllm_model.model._add_request( + prompt=example_prompts[0], + sampling_params=params_with_logprobs, + prompt_token_ids=None, + ) + + # test prompt_logprobs is not None + vllm_model.model._add_request( + prompt=example_prompts[1], + sampling_params=SamplingParams( + prompt_logprobs=3, + max_tokens=max_tokens, + ), + prompt_token_ids=None, + ) + + # test grouped requests + vllm_model.model._add_request( + prompt=example_prompts[2], + sampling_params=SamplingParams(max_tokens=max_tokens), + prompt_token_ids=None, + ) + + outputs = vllm_model.model._run_engine(False) + + assert outputs[0].outputs[0].text == enforced_answers * repeat_times diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py new file mode 100644 index 0000000000000..1d99cb5d32219 --- /dev/null +++ b/tests/spec_decode/e2e/conftest.py @@ -0,0 +1,41 @@ +import pytest + +from tests.conftest import cleanup +from vllm import LLM +from vllm.model_executor.utils import set_random_seed + + +@pytest.fixture +def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, seed): + return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, seed) + + +@pytest.fixture +def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, + test_llm_kwargs, seed): + return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, + test_llm_kwargs, seed) + + +def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, + distinct_llm_kwargs, seed): + kwargs = { + **common_llm_kwargs, + **per_test_common_llm_kwargs, + **distinct_llm_kwargs, + } + + def generator_inner(): + llm = LLM(**kwargs) + + set_random_seed(seed) + + yield llm + del llm + cleanup() + + for llm in generator_inner(): + yield llm + del llm diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py new file mode 100644 index 0000000000000..b5a6fcb7900a3 --- /dev/null +++ b/tests/spec_decode/e2e/test_correctness.py @@ -0,0 +1,50 @@ +import pytest + +from vllm import SamplingParams + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "facebook/opt-125m", + "speculative_model": "facebook/opt-125m", + "num_speculative_tokens": 5, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_config(test_llm_generator): + output_len = 1024 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + with pytest.raises( + AssertionError, + match="Speculative decoding not yet supported for GPU backend"): + get_token_ids_from_llm_generator(test_llm_generator, prompts, + sampling_params) + + +def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): + for llm in llm_generator: + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + token_ids = [output.outputs[0].token_ids for output in outputs] + del llm + + return token_ids diff --git a/tests/spec_decode/test_batch_expansion.py b/tests/spec_decode/test_batch_expansion.py index 80a960acf0be5..43cfd78ddb0cc 100644 --- a/tests/spec_decode/test_batch_expansion.py +++ b/tests/spec_decode/test_batch_expansion.py @@ -7,6 +7,7 @@ from .utils import create_seq_group_metadata_from_prompts, mock_worker @pytest.mark.parametrize('num_target_seq_ids', [100]) +@pytest.mark.skip_global_cleanup def test_create_target_seq_id_iterator(num_target_seq_ids: int): """Verify all new sequence ids are greater than all input seq ids. @@ -27,6 +28,7 @@ def test_create_target_seq_id_iterator(num_target_seq_ids: int): @pytest.mark.parametrize('k', [1, 2, 6]) +@pytest.mark.skip_global_cleanup def test_get_token_ids_to_score(k: int): """Verify correct tokens are selected for scoring. """ @@ -53,6 +55,7 @@ def test_get_token_ids_to_score(k: int): @pytest.mark.parametrize('k', [1, 2, 6]) +@pytest.mark.skip_global_cleanup def test_create_single_target_seq_group_metadata(k: int): """Verify correct creation of a batch-expanded seq group metadata. """ diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 87d3716ca98d7..47aff8f575413 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -487,7 +487,7 @@ def test_empty_input_batch(k: int, batch_size: int): **execute_model_data.to_dict()) -@torch.inference_mode() +@pytest.mark.skip_global_cleanup def test_init_device(): """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as well as other GPU initialization. @@ -512,8 +512,8 @@ def test_init_device(): @torch.inference_mode() -def test_init_cache_engine(): - """Verify SpecDecodeWorker invokes init_cache_engine on proposer/scorer +def test_initialize_cache(): + """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer workers. """ draft_worker = mock_worker(cls=MultiStepWorker) @@ -525,23 +525,22 @@ def test_init_cache_engine(): worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, metrics_collector) - cache_config = MagicMock() + kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} + worker.initialize_cache(**kwargs) - worker.init_cache_engine(cache_config) - - draft_worker.init_cache_engine.assert_called_once_with(cache_config) - target_worker.init_cache_engine.assert_called_once_with(cache_config) + draft_worker.initialize_cache.assert_called_once_with(**kwargs) + target_worker.initialize_cache.assert_called_once_with(**kwargs) @pytest.mark.parametrize('available_gpu_blocks', [1, 1024]) @pytest.mark.parametrize('available_cpu_blocks', [500]) @pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096]) @pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096]) -@torch.inference_mode() -def test_profile_num_available_blocks(available_gpu_blocks: int, - available_cpu_blocks: int, - target_cache_block_size_bytes: int, - draft_kv_size_bytes: int): +@pytest.mark.skip_global_cleanup +def test_determine_num_available_blocks(available_gpu_blocks: int, + available_cpu_blocks: int, + target_cache_block_size_bytes: int, + draft_kv_size_bytes: int): """Verify SpecDecodeWorker correctly profiles num available GPU blocks. Specifically, it should run profiling in the scorer worker, and then evenly split the blocks between proposer and scorer worker. @@ -552,7 +551,7 @@ def test_profile_num_available_blocks(available_gpu_blocks: int, rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) - target_worker.profile_num_available_blocks.return_value = ( + target_worker.determine_num_available_blocks.return_value = ( available_gpu_blocks, available_cpu_blocks) target_worker.get_cache_block_size_bytes.return_value = ( target_cache_block_size_bytes) @@ -561,17 +560,9 @@ def test_profile_num_available_blocks(available_gpu_blocks: int, worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, metrics_collector) - # These values do not directly impact the adjusted block size calculation, - # so they can be fixed. - gpu_memory_utilization = 0.9 - cpu_swap_space = 100 - block_size = 16 + num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks() - num_gpu_blocks, num_cpu_blocks = worker.profile_num_available_blocks( - block_size, gpu_memory_utilization, cpu_swap_space, cache_dtype="auto") - - target_worker.profile_num_available_blocks.assert_called_once_with( - block_size, gpu_memory_utilization, cpu_swap_space, "auto") + target_worker.determine_num_available_blocks.assert_called_once() assert num_cpu_blocks == available_cpu_blocks assert num_gpu_blocks == split_num_cache_blocks_evenly( @@ -584,7 +575,7 @@ def test_profile_num_available_blocks(available_gpu_blocks: int, @pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096, 2 * 2 * 8192]) @pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096]) -@torch.inference_mode() +@pytest.mark.skip_global_cleanup def test_split_num_cache_blocks_evenly(available_gpu_blocks: int, target_cache_block_size_bytes: int, draft_kv_size_bytes: int): diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 0cd9a4b1d5815..4637826f254d6 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -107,18 +107,17 @@ def create_worker(cls: type, block_size=block_size, enforce_eager=enforce_eager, ) - - (model_config, cache_config, parallel_config, scheduler_config, - device_config, _, _) = engine_args.create_engine_configs() + engine_config = engine_args.create_engine_config() distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) worker = cls( - model_config=model_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + cache_config=engine_config.cache_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, @@ -128,10 +127,11 @@ def create_worker(cls: type, worker.init_device() worker.load_model() - cache_config.num_gpu_blocks = num_gpu_blocks - cache_config.num_cpu_blocks = 0 - worker.init_cache_engine(cache_config) - worker.warm_up_model() + engine_config.cache_config.num_gpu_blocks = num_gpu_blocks + engine_config.cache_config.num_cpu_blocks = 0 + worker.initialize_cache( + num_gpu_blocks=engine_config.cache_config.num_gpu_blocks, + num_cpu_blocks=engine_config.cache_config.num_cpu_blocks) return worker diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 1dec928158b16..b16bdc141e57c 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,7 +1,36 @@ +import time +from typing import Optional + import pytest -from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput, - SequenceOutput) +from vllm import SamplingParams +from vllm.lora.request import LoRARequest +from vllm.sequence import (SamplerOutput, Sequence, SequenceData, + SequenceGroup, SequenceGroupOutput, SequenceOutput) + + +def create_dummy_prompt( + request_id: str, + prompt_length: int, + block_size: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + use_beam_search: bool = False, + best_of: int = 1, +) -> SequenceGroup: + if not block_size: + block_size = prompt_length + + # Create dummy prompt sequence with tokens 0...block_size-1 + # and prompt "0 ... block_size". + prompt_tokens = list(range(prompt_length)) + prompt_str = " ".join([str(t) for t in prompt_tokens]) + prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) + seq_group = SequenceGroup( + request_id, [prompt], + SamplingParams(use_beam_search=use_beam_search, best_of=best_of), + time.time(), lora_request) + + return seq_group @pytest.fixture @@ -67,6 +96,29 @@ def test_sequence_data_prefill(): # append tokens and reset, simulating recompute seq_data.append_token_id(1, logprob=0.0) - seq_data.reset_num_computed_tokens() + seq_data.reset_state_for_recompute() assert seq_data.get_num_uncomputed_tokens() == 5 assert seq_data.get_num_computed_tokens() == 0 + + +def test_sequence_group_stage(): + seq_group = create_dummy_prompt("1", 12) + assert seq_group.is_prefill() is True + seq_group.update_num_computed_tokens(6) + assert seq_group.is_prefill() is True + seq_group.update_num_computed_tokens(5) + assert seq_group.is_prefill() is True + seq_group.update_num_computed_tokens(1) + assert seq_group.is_prefill() is False + seqs = seq_group.get_seqs() + assert len(seqs) == 1 + seqs[0].data.append_token_id(1, logprob=0.0) + for seq in seq_group.get_seqs(): + seq.reset_state_for_recompute() + assert seq_group.is_prefill() is True + seq_group.update_num_computed_tokens(5) + assert seq_group.is_prefill() is True + seq_group.update_num_computed_tokens(7) + assert seq_group.is_prefill() is True + seq_group.update_num_computed_tokens(1) + assert seq_group.is_prefill() is False diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 92587b40dd45a..9bc9becb2a6f1 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -4,8 +4,8 @@ import pytest from transformers import AutoTokenizer from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer import detokenize_incrementally +from vllm.transformers_utils.detokenizer import (Detokenizer, + detokenize_incrementally) from vllm.transformers_utils.tokenizer_group import get_tokenizer_group TRUTH = [ diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 0bbf85f590758..8edb1cf05c08e 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -10,19 +10,19 @@ def test_swap() -> None: engine_args = EngineArgs(model="facebook/opt-125m", dtype="half", load_format="dummy") - (model_config, cache_config, parallel_config, scheduler_config, - device_config, _, _) = engine_args.create_engine_configs() - cache_config.num_gpu_blocks = 100 - cache_config.num_cpu_blocks = 100 + engine_config = engine_args.create_engine_config() + engine_config.cache_config.num_gpu_blocks = 1000 + engine_config.cache_config.num_cpu_blocks = 1000 # Create the worker. distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) worker = Worker( - model_config=model_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + cache_config=engine_config.cache_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, @@ -32,8 +32,9 @@ def test_swap() -> None: # Initialize the worker. worker.init_device() worker.load_model() - worker.init_cache_engine(cache_config) - worker.warm_up_model() + worker.initialize_cache( + num_gpu_blocks=engine_config.cache_config.num_gpu_blocks, + num_cpu_blocks=engine_config.cache_config.num_cpu_blocks) # Randomly initialize the cache. gpu_cache = worker.cache_engine.gpu_cache diff --git a/vllm/__init__.py b/vllm/__init__.py index d53e591bcb062..2c1fd40573240 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,13 +5,15 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.llm_engine import LLMEngine from vllm.engine.ray_utils import initialize_ray_cluster from vllm.entrypoints.llm import LLM +from vllm.model_executor.models import ModelRegistry from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import SamplingParams -__version__ = "0.4.0" +__version__ = "0.4.0.post1" __all__ = [ "LLM", + "ModelRegistry", "SamplingParams", "RequestOutput", "CompletionOutput", diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index a7e0ab92c7668..a03cf2dd7a6fa 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -81,5 +81,6 @@ class AttentionImpl(ABC): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, + kv_scale: float, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index e50d52377b8e0..4e0d9d1418b32 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -156,6 +156,7 @@ class FlashAttentionImpl(AttentionImpl): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, + kv_scale: float, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -184,7 +185,8 @@ class FlashAttentionImpl(AttentionImpl): PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype) + attn_metadata.kv_cache_dtype, + kv_scale) if attn_metadata.is_prompt: # Prompt run. @@ -207,6 +209,9 @@ class FlashAttentionImpl(AttentionImpl): ) else: # prefix-enabled attention + # TODO(Hai) this triton kernel has regression issue (broke) to + # deal with different data types between KV and FP8 KV cache, + # to be addressed separately. output = PagedAttention.forward_prefix( query, key, @@ -233,6 +238,7 @@ class FlashAttentionImpl(AttentionImpl): self.num_kv_heads, self.scale, self.alibi_slopes, + kv_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py new file mode 100644 index 0000000000000..6019d917b4494 --- /dev/null +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -0,0 +1,348 @@ +"""Attention layer ROCm GPUs.""" +import os +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class ROCmFlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: + return ROCmFlashAttentionImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata": + return ROCmFlashAttentionMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + # (batch_size,). The prompt length per sequence. None if it is a decoding. + prompt_lens: Optional[List[int]] + # prompt_lens stored as a tensor. + prompt_lens_tensor: Optional[torch.Tensor] + # The number of prompt tokens. Doesn't include padding. + num_prompt_tokens: int + # The number of generation tokens. Doesn't include padding. + num_generation_tokens: int + + # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seqlen ----------------------| + # |- subquery_len -| + + # WARNING(sang): context_len has different definition depending on if it is + # prefill vs decoding. When it is prefill, it doesn't include new tokens. + # When it is for decoding, it includes a new token. + + # Maximum subquery length in the batch. + max_subquery_len: Optional[int] + # Maximum prompt length in the batch. + max_prompt_len: Optional[int] + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + subquery_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + +class ROCmFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9 + # NOTE: Allow for switching between Triton and CK. Defaulting to triton. + self.use_triton_flash_attn = (os.environ.get( + "VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")) + if self.use_naive_attn: + # AMD Radeon 7900 series (gfx1100) currently does not support + # xFormers nor FlashAttention. As a temporary workaround, we use + # naive PyTorch implementation of attention. + self.attn_fuc = _naive_attention() + logger.debug("Using naive attention in ROCmBackend") + elif self.use_triton_flash_attn: + from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 + triton_attention) + self.attn_func = triton_attention + logger.debug("Using Triton FA in ROCmBackend") + else: + from flash_attn import flash_attn_varlen_func # noqa: F401 + self.attn_func = flash_attn_varlen_func + logger.debug("Using CK FA in ROCmBackend") + + def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" + tokens, n_kv_heads, head_dim = x.shape + return (x[:, :, + None, :].expand(tokens, n_kv_heads, n_rep, + head_dim).reshape(tokens, n_kv_heads * n_rep, + head_dim)) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: ROCmFlashAttentionMetadata, + kv_scale: float = 1.0, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache is not None: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + attn_metadata.kv_cache_dtype, + kv_scale, + ) + + if attn_metadata.is_prompt: + # Prompt run. + if kv_cache is None or attn_metadata.block_tables.numel() == 0: + # triton attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + if self.use_naive_attn or self.use_triton_flash_attn: + if self.num_kv_heads != self.num_heads: + # Interleave for MQA workaround. + key = self.repeat_kv(key, self.num_queries_per_kv) + value = self.repeat_kv(value, self.num_queries_per_kv) + if self.use_naive_attn: + output = self.attn_fuc( + query, + key, + value, + attn_metadata.prompt_lens, + self.scale, + ) + else: + output, _ = self.attn_func( + query, + key, + value, + None, + attn_metadata.seq_start_loc, + attn_metadata.seq_start_loc, + attn_metadata.max_prompt_len, + attn_metadata.max_prompt_len, + True, + self.scale, + ) + else: + output = self.attn_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_prompt_len, + max_seqlen_k=attn_metadata.max_prompt_len, + softmax_scale=self.scale, + causal=True, + ) + + else: + # prefix-enabled attention + output = PagedAttention.forward_prefix( + query, + key, + value, + key_cache, + value_cache, + attn_metadata.block_tables, + attn_metadata.subquery_start_loc, + attn_metadata.prompt_lens_tensor, + attn_metadata.context_lens, + attn_metadata.max_subquery_len, + self.alibi_slopes, + ) + else: + # Decoding run. + output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.max_context_len, + attn_metadata.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + kv_scale, + ) + + # Reshape the output tensor. + return output.view(num_tokens, hidden_size) + + +def _naive_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + prompt_lens: List[int], + scale: float, +) -> torch.Tensor: + num_tokens = query.shape[0] + output = torch.empty_like(query) + start = 0 + for _, prompt_len in enumerate(prompt_lens): + end = start + prompt_len + out = _naive_masked_attention( + query[None, start:end], + key[None, start:end], + value[None, start:end], + scale, + ) + # TODO(woosuk): Unnecessary copy. Optimize. + output[start:end].copy_(out) + start += prompt_len + + # Using view got RuntimeError: view size is not compatible + # with input tensor's size and stride (at least one + # dimension spans across two contiguous subspaces). + # Use reshape instead. + return output.reshape(num_tokens, -1) + + +def _naive_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, +) -> torch.Tensor: + seq_len, _, _ = query.shape + attn_mask = torch.triu(torch.ones(seq_len, + seq_len, + dtype=query.dtype, + device=query.device), + diagonal=1) + attn_mask = attn_mask * torch.finfo(query.dtype).min + + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py new file mode 100644 index 0000000000000..9706e1910cb79 --- /dev/null +++ b/vllm/attention/backends/torch_sdpa.py @@ -0,0 +1,256 @@ +""" Attention layer with torch scaled_dot_product_attention + and PagedAttention.""" +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Type + +import torch +from torch.nn.functional import scaled_dot_product_attention + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) + + +class TorchSDPABackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["TorchSDPABackendImpl"]: + return TorchSDPABackendImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata": + return TorchSDPAMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for TorchSDPABackend. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + slot_mapping: torch.Tensor + prompt_lens: Optional[List[int]] + prompt_lens_tensor: Optional[torch.Tensor] + num_prompt_tokens: int + num_generation_tokens: int + + max_subquery_len: Optional[int] = None + max_prompt_len: Optional[int] = None + subquery_start_loc: Optional[torch.Tensor] = None + seq_start_loc: Optional[torch.Tensor] = None + use_cuda_graph: bool = False + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[torch.Tensor]] = None + + +class TorchSDPABackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = sliding_window + if alibi_slopes is not None: + assert len(alibi_slopes) == num_heads + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: TorchSDPAMetadata, + kv_scale: float, + ) -> torch.Tensor: + """Forward pass with torch SDPA and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache is not None: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + attn_metadata.slot_mapping, + attn_metadata.kv_cache_dtype, + kv_scale) + + if attn_metadata.is_prompt: + if (kv_cache is None or attn_metadata.block_tables.numel() == 0): + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=1) + + if attn_metadata.attn_bias is None: + if self.alibi_slopes is not None: + att_masks = _make_alibi_bias( + self.alibi_slopes, query.dtype, + attn_metadata.prompt_lens) # type: ignore + elif self.sliding_window is not None: + att_masks = _make_sliding_window_bias( + attn_metadata.prompt_lens, self.sliding_window, + query.dtype) # type: ignore + else: + att_masks = [None] * len(attn_metadata.prompt_lens) + attn_metadata.attn_bias = att_masks + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + start = 0 + output = torch.empty( + (num_tokens, self.num_heads, self.head_size), + dtype=query.dtype) + for prompt_len, mask in zip(attn_metadata.prompt_lens, + attn_metadata.attn_bias): + end = start + prompt_len + sub_out = scaled_dot_product_attention( + query[:, start:end, :], + key[:, start:end, :], + value[:, start:end, :], + attn_mask=mask, + dropout_p=0.0, + is_causal=not self.need_mask, + scale=self.scale).movedim(query.dim() - 2, 0) + output[start:end, :, :] = sub_out + start = end + else: + # prefix-enabled attention + raise RuntimeError( + "Torch SDPA backend doesn't support prefix decoding.") + + else: + # Decoding run. + output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.max_context_len, + attn_metadata.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + kv_scale, + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + prompt_lens: List[int], +) -> List[torch.Tensor]: + attn_biases = [] + for prompt_len in prompt_lens: + bias = torch.arange(prompt_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(prompt_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].expand(num_heads, prompt_len, prompt_len) + bias.mul_(alibi_slopes[:, None, None]) + inf_mask = torch.empty( + (1, prompt_len, prompt_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_sliding_window_bias( + prompt_lens: List[int], + window_size: Optional[int], + dtype: torch.dtype, +) -> List[torch.Tensor]: + attn_biases = [] + for prompt_len in prompt_lens: + tensor = torch.full( + (1, prompt_len, prompt_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + attn_biases.append(mask.to(dtype)) + + return attn_biases diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 8e510f975059e..05b68bba5e6eb 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -1,5 +1,4 @@ """Attention layer with xFormers and PagedAttention.""" -import importlib from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Type @@ -14,7 +13,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger -from vllm.utils import is_hip logger = init_logger(__name__) @@ -166,11 +164,6 @@ class XFormersImpl(AttentionImpl): f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {suppored_head_sizes}.") - # AMD Radeon 7900 series (gfx1100) currently does not support xFormers - # nor FlashAttention. As a temporary workaround, we use naive PyTorch - # implementation of attention. - self.use_naive_attention = _check_use_naive_attention() - def forward( self, query: torch.Tensor, @@ -178,6 +171,7 @@ class XFormersImpl(AttentionImpl): value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: XFormersMetadata, + kv_scale: float, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -205,7 +199,8 @@ class XFormersImpl(AttentionImpl): PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype) + attn_metadata.kv_cache_dtype, + kv_scale) if attn_metadata.is_prompt: # Prompt run. @@ -231,34 +226,13 @@ class XFormersImpl(AttentionImpl): self.num_queries_per_kv, value.shape[-1]) - if self.use_naive_attention: - output = torch.empty_like(query) - start = 0 - for _, prompt_len in enumerate(attn_metadata.prompt_lens): - end = start + prompt_len - out = _naive_masked_attention( - query[None, start:end], - key[None, start:end], - value[None, start:end], - self.num_heads, - self.num_kv_heads, - self.head_size, - self.scale, - ) - # TODO(woosuk): Unnecessary copy. Optimize. - output[start:end].copy_(out) - start += prompt_len - - # Using view got RuntimeError: view size is not compatible - # with input tensor's size and stride (at least one - # dimension spans across two contiguous subspaces). - # Use reshape instead. - return output.reshape(num_tokens, hidden_size) - output = self._run_memory_efficient_xformers_forward( query, key, value, attn_metadata) else: # prefix-enabled attention + # TODO(Hai) this triton kernel has regression issue (broke) to + # deal with different data types between KV and FP8 KV cache, + # to be addressed separately. output = PagedAttention.forward_prefix( query, key, @@ -285,6 +259,7 @@ class XFormersImpl(AttentionImpl): self.num_kv_heads, self.scale, self.alibi_slopes, + kv_scale, ) # Reshape the output tensor. @@ -323,8 +298,6 @@ class XFormersImpl(AttentionImpl): self.alibi_slopes, self.num_kv_heads, query.dtype, attn_metadata.prompt_lens) - op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if ( - is_hip()) else None # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. @@ -338,8 +311,7 @@ class XFormersImpl(AttentionImpl): value, attn_bias=attn_metadata.attn_bias[0], p=0.0, - scale=self.scale, - op=op) + scale=self.scale) return out.view_as(query) @@ -357,8 +329,7 @@ class XFormersImpl(AttentionImpl): value[None, start:end], attn_bias=attn_metadata.attn_bias[i], p=0.0, - scale=self.scale, - op=op) + scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out.squeeze(0)) start += prompt_len @@ -399,42 +370,3 @@ def _make_alibi_bias( attn_biases.append(LowerTriangularMaskWithTensorBias(bias)) return attn_biases - - -def _check_use_naive_attention() -> bool: - if not is_hip(): - return False - # For ROCm, check whether flash attention is installed or not. - use_naive_attention = importlib.util.find_spec("flash_attn") is None - if use_naive_attention: - logger.warning("flash_attn is not installed. Using naive attention. " - "This will take significantly more GPU memory.") - return True - return False - - -def _naive_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - num_kv_heads: int, - head_size: int, - scale: float, -) -> torch.Tensor: - query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - seq_len, _, _ = query.shape - attn_mask = torch.triu(torch.ones(seq_len, - seq_len, - dtype=query.dtype, - device=query.device), - diagonal=1) - attn_mask = attn_mask * torch.finfo(query.dtype).min - - attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() - attn_weights = attn_weights + attn_mask.float() - attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) - out = torch.einsum("hqk,khd->qhd", attn_weights, value) - return out diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 2e0aa18e52427..9856654fc5f94 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -42,5 +42,7 @@ class Attention(nn.Module): value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata, + kv_scale: float = 1.0, ) -> torch.Tensor: - return self.impl.forward(query, key, value, kv_cache, attn_metadata) + return self.impl.forward(query, key, value, kv_cache, attn_metadata, + kv_scale) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 5901af4f0a02f..256bffdf032eb 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -73,6 +73,7 @@ class PagedAttention: value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, + kv_scale: float, ) -> None: cache_ops.reshape_and_cache( key, @@ -81,6 +82,7 @@ class PagedAttention: value_cache, slot_mapping.flatten(), kv_cache_dtype, + kv_scale, ) @staticmethod @@ -95,6 +97,7 @@ class PagedAttention: num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], + kv_scale: float, ) -> torch.Tensor: output = torch.empty_like(query) @@ -126,6 +129,7 @@ class PagedAttention: max_context_len, alibi_slopes, kv_cache_dtype, + kv_scale, ) else: # Run PagedAttention V2. @@ -157,6 +161,7 @@ class PagedAttention: max_context_len, alibi_slopes, kv_cache_dtype, + kv_scale, ) return output diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py new file mode 100644 index 0000000000000..b86e845020b07 --- /dev/null +++ b/vllm/attention/ops/triton_flash_attention.py @@ -0,0 +1,809 @@ +#!/usr/bin/env python +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao +(https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team + +Features supported: + +1) Fwd with causal masking +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. + +Not currently supported: + +1) Non power of two head dims + +""" + +import torch +import triton +import triton.language as tl + +torch_dtype: tl.constexpr = torch.float16 + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, + stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, + stride) + rng_keep = rng_output > dropout_p + return rng_keep + + +@triton.jit +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) + else: + tensor = tl.load(block_ptr) + return tensor + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + actual_seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr, +): + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + k = load_fn( + K_block_ptr, + PADDED_HEAD, + MASK_STEPS and (n_extra_tokens != 0), + "zero", + ) + if PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: # noqa: SIM102 + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps + # if not is_modulo_mn. last step might get wasted but that is okay. + # check if this masking works for that case. + if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], + actual_seqlen_k, + dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + qk += tl.dot(q, k) + if bias_ptr is not None: + bias = load_fn(bias_ptr, False, MASK_STEPS + and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, our + # optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += bias * 1.44269504089 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = (batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + start_n - + BLOCK_N) + keep = dropout_mask( + philox_seed, + philox_offset, + dropout_p, + BLOCK_M, + BLOCK_N, + actual_seqlen_k, + ) + if RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + tl.where(keep, p, + -p).to(encoded_softmax_block_ptr.type.element_ty), + ) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + p.to(encoded_softmax_block_ptr.type.element_ty), + ) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, BLOCK_N)) + return acc, l_i, m_i + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": True, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 64, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 32, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + # TODO: This config fails with head_size not pow2 with data mismatches. + # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, + # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config( + { + "BLOCK_M": 16, + "BLOCK_N": 16, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + ], + key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"], +) +@triton.jit +def attn_fwd( + Q, + K, + V, + bias, + sm_scale, + L, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_bz, + stride_bh, + stride_bm, + stride_bn, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset_base, + encoded_softmax, + hq, + hk, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, +): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if IS_CAUSAL: + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn + # matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is + # part of the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + # tl.store(O_block_ptr, + # acc.to(Out.type.element_ty), boundary_check=(0,1)) + # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + # + offs_m + # We store inf to LSE, not -inf because in the bwd pass, + # we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 + # for these masked blocks. + # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + # tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here? + return + + is_mqa = hq != hk + off_h_k = off_h_q % hk if is_mqa else off_h_q + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + + # Compute pointers for all the tensors used in this kernel. + q_offset = (off_z * stride_qz + off_h_q * stride_qh + + cu_seqlens_q_start * stride_qm) + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_offset = (off_z * stride_kz + off_h_k * stride_kh + + cu_seqlens_k_start * stride_kn) + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + v_offset = (off_z * stride_vz + off_h_k * stride_vh + + cu_seqlens_k_start * stride_vk) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + if BIAS_TYPE != 0: + bias_ptr = tl.make_block_ptr( + base=bias + off_h_q * stride_bh, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base \ + + (off_z * hq + off_h_q) \ + * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. + # In this case, we return an invalid pointer so indicate the mask is not i + # valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, padded_head, "zero") + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional + # block. In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + bias_ptr, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + False, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if masked_blocks > 0: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, n_full_blocks)) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + True, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + ) + # epilogue + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: # noqa: SIM102 + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), + causal_start_idx, + dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = (mask_m_offsets[:, None] >= + out_mask_boundary[None, :]) + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last + # few rows. This is only true for the last M block. For others, + # overflow_size will be -ve + # overflow_size = end_m_idx - seqlen_q + # if overflow_size > 0: + # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # # This is a > check because mask being 0 blocks the store. + # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + # else: + # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. + tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + + +def check_args( + q, + k, + v, + o, + varlen=True, + max_seqlens=None, + cu_seqlens_q=None, + cu_seqlens_k=None, +): + assert q.dim() == k.dim() and q.dim() == v.dim() + if varlen: + assert q.dim() == 3 + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + assert cu_seqlens_q is not None + assert cu_seqlens_k is not None + assert len(cu_seqlens_q) == len(cu_seqlens_k) + else: + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert max_seqlens > 0 + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + # TODO: Fix assert to check head size <=256 once supported + assert head_size <= 128 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + o, + cu_seqlens_q, + cu_seqlens_k, + max_seqlens_q, + max_seqlens_k, + causal=False, + sm_scale=1.0, + bias=None, + ): + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + + check_args( + q, + k, + v, + o, + varlen=True, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + if True: # varlen + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + batch = len(cu_seqlens_q) - 1 + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + batch, seqlen_q, nheads_q, head_size = q.shape + _, seqlen_k, nheads_k, _ = k.shape + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + + # Get closest power of 2 over or equal to 32. + unpadded_head_dims = {32, 64, 128} + if head_size not in unpadded_head_dims: + padded_d_model = None + for i in unpadded_head_dims: + if i > head_size: + padded_d_model = i + break + assert padded_d_model is not None + else: + padded_d_model = head_size + + grid = lambda META: ( + triton.cdiv(max_seqlens_q, META["BLOCK_M"]), + nheads_q, + batch, + ) + + encoded_softmax = None + + # Seed the RNG so we get reproducible results for testing. + philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if bias is not None: + bias_strides = ( + bias.stride(0), + bias.stride(1), + bias.stride(2), + bias.stride(3), + ) + else: + bias_strides = (0, 0, 0, 0) + + attn_fwd[grid]( + q, + k, + v, + bias, + sm_scale, + None, + o, + *q_strides, + *k_strides, + *v_strides, + *o_strides, + *bias_strides, + cu_seqlens_q, + cu_seqlens_k, + dropout_p=0.0, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, + hq=nheads_q, + hk=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=max_seqlens_q, + MAX_SEQLENS_K=max_seqlens_k, + IS_CAUSAL=causal, + VARLEN=True, + BLOCK_DMODEL=padded_d_model, + BIAS_TYPE=0 if bias is None else 1, + ENABLE_DROPOUT=False, + RETURN_ENCODED_SOFTMAX=False, + ) + + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = causal + ctx.dropout_p = 0.0 + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = False + return o, encoded_softmax + + +triton_attention = _attention.apply diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 1fd2d482e729c..20e5c2cda72b1 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,3 +1,4 @@ +import enum from functools import lru_cache from typing import Type @@ -5,50 +6,81 @@ import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger -from vllm.utils import is_hip +from vllm.utils import is_cpu, is_hip, is_tpu logger = init_logger(__name__) +class _Backend(enum.Enum): + FLASH_ATTN = enum.auto() + XFORMERS = enum.auto() + ROCM_FLASH = enum.auto() + TORCH_SDPA = enum.auto() + PALLAS = enum.auto() + + @lru_cache(maxsize=None) def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: - if True: - logger.info("Using PallasAttention backend.") - from vllm.attention.backends.pallas import ( # noqa: F401 - PallasAttentionBackend) - return PallasAttentionBackend - elif _can_use_flash_attn(dtype): + backend = _which_attn_to_use(dtype) + if backend == _Backend.FLASH_ATTN: logger.info("Using FlashAttention backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) return FlashAttentionBackend - else: + elif backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 XFormersBackend) return XFormersBackend + elif backend == _Backend.ROCM_FLASH: + logger.info("Using ROCmFlashAttention backend.") + from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401 + ROCmFlashAttentionBackend) + return ROCmFlashAttentionBackend + elif backend == _Backend.TORCH_SDPA: + logger.info("Using Torch SDPA backend.") + from vllm.attention.backends.torch_sdpa import TorchSDPABackend + return TorchSDPABackend + elif backend == _Backend.PALLAS: + logger.info("Using PallasAttention backend.") + from vllm.attention.backends.pallas import ( # noqa: F401 + PallasAttentionBackend) + return PallasAttentionBackend + else: + raise ValueError("Invalid attention backend.") -def _can_use_flash_attn(dtype: torch.dtype) -> bool: +def _which_attn_to_use(dtype: torch.dtype) -> _Backend: + """Returns which flash attention backend to use.""" + if is_tpu(): + return _Backend.PALLAS + if is_cpu(): + return _Backend.TORCH_SDPA + if is_hip(): # AMD GPUs. - logger.info("Cannot use FlashAttention backend for AMD GPUs.") - return False + if torch.cuda.get_device_capability()[0] != 9: + # not Instinct series GPUs. + logger.info("flash_atten is not supported on NAVI GPUs.") + return _Backend.ROCM_FLASH + + # NVIDIA GPUs. if torch.cuda.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. logger.info("Cannot use FlashAttention backend for Volta and Turing " "GPUs.") - return False + return _Backend.XFORMERS + if dtype not in (torch.float16, torch.bfloat16): logger.info("Cannot use FlashAttention backend for dtype other than " "torch.float16 or torch.bfloat16.") - return False + return _Backend.XFORMERS try: import flash_attn # noqa: F401 except ImportError: logger.info( - "Cannot use FlashAttention because the package is not found. " - "Please install it for better performance.") - return False - return True + "Cannot use FlashAttention backend because the flash_attn package " + "is not found. Please install it for better performance.") + return _Backend.XFORMERS + return _Backend.FLASH_ATTN diff --git a/vllm/config.py b/vllm/config.py index 8dd670a4eccf8..0110a9235c301 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,7 +1,7 @@ import enum import json import os -from dataclasses import dataclass +from dataclasses import dataclass, fields from typing import TYPE_CHECKING, ClassVar, Optional, Union import torch @@ -10,7 +10,7 @@ from transformers import PretrainedConfig from vllm.logger import init_logger from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_hip, +from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip, is_neuron, is_tpu) if TYPE_CHECKING: @@ -60,6 +60,11 @@ class ModelConfig: output). If None, will be derived from the model. quantization: Quantization method that was used to quantize the model weights. If None, we assume the model weights are not quantized. + quantization_param_path: Path to JSON file containing scaling factors. + Used to load KV cache scaling factors into the model when KV cache + type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also + be used to load activation and weight scaling factors when the + model dtype is FP8_E4M3 on ROCm. enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. @@ -83,6 +88,7 @@ class ModelConfig: tokenizer_revision: Optional[str] = None, max_model_len: Optional[int] = None, quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, max_logprobs: int = 5, @@ -98,6 +104,7 @@ class ModelConfig: self.code_revision = code_revision self.tokenizer_revision = tokenizer_revision self.quantization = quantization + self.quantization_param_path = quantization_param_path self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture self.max_logprobs = max_logprobs @@ -172,26 +179,28 @@ class ModelConfig: self.quantization = self.quantization.lower() # Parse quantization method from the HF model config, if available. - hf_quant_config = getattr(self.hf_config, "quantization_config", None) - if hf_quant_config is not None: - hf_quant_method = str(hf_quant_config["quant_method"]).lower() + quant_cfg = getattr(self.hf_config, "quantization_config", None) + if quant_cfg is not None: + quant_method = quant_cfg.get("quant_method", "").lower() + # compat: autogptq >=0.8.0 use checkpoint_format: str + # compat: autogptq <=0.7.1 is_marlin_format: bool + is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin" + or quant_cfg.get("is_marlin_format", False)) - # If the GPTQ model is serialized in marlin format, use marlin. - if (hf_quant_method == "gptq" - and "is_marlin_format" in hf_quant_config - and hf_quant_config["is_marlin_format"]): + # Use marlin if the GPTQ model is serialized in marlin format. + if quant_method == "gptq" and is_format_marlin: logger.info("The model is serialized in Marlin format. " "Using Marlin kernel.") - hf_quant_method = "marlin" + quant_method = "marlin" if self.quantization == "gptq": - self.quantization = hf_quant_method + self.quantization = quant_method if self.quantization is None: - self.quantization = hf_quant_method - elif self.quantization != hf_quant_method: + self.quantization = quant_method + elif self.quantization != quant_method: raise ValueError( "Quantization method specified in the model config " - f"({hf_quant_method}) does not match the quantization " + f"({quant_method}) does not match the quantization " f"method specified in the `quantization` argument " f"({self.quantization}).") @@ -325,7 +334,7 @@ class CacheConfig: vLLM execution. swap_space: Size of the CPU swap space per GPU (in GiB). cache_dtype: Data type for kv cache storage. - forced_num_gpu_blocks: Number of GPU blocks to use. This overrides the + num_gpu_blocks_override: Number of GPU blocks to use. This overrides the profiled num_gpu_blocks if specified. Does nothing if None. """ @@ -335,14 +344,14 @@ class CacheConfig: gpu_memory_utilization: float, swap_space: int, cache_dtype: str, - forced_num_gpu_blocks: Optional[int] = None, + num_gpu_blocks_override: Optional[int] = None, sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization self.swap_space_bytes = swap_space * _GB - self.forced_num_gpu_blocks = forced_num_gpu_blocks + self.num_gpu_blocks_override = num_gpu_blocks_override self.cache_dtype = cache_dtype self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching @@ -367,21 +376,20 @@ class CacheConfig: def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass - elif self.cache_dtype == "fp8_e5m2": - if is_hip(): - raise NotImplementedError( - "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.") - nvcc_cuda_version = get_nvcc_cuda_version() - if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"): - raise ValueError( - "FP8 is not supported when cuda version is lower than 11.8." - ) + elif self.cache_dtype == "fp8": + if not is_hip(): + nvcc_cuda_version = get_nvcc_cuda_version() + if nvcc_cuda_version < Version("11.8"): + raise ValueError( + "FP8 is not supported when cuda version is" + "lower than 11.8.") logger.info( - "Using fp8_e5m2 data type to store kv cache. It reduces " - "the GPU memory footprint and boosts the performance. " - "But it may cause slight accuracy drop. " - "Currently we only support fp8 without scaling factors and " - "make e5m2 as a default format.") + "Using fp8 data type to store kv cache. It reduces the GPU " + "memory footprint and boosts the performance. " + "But it may cause slight accuracy drop without scaling " + "factors. FP8_E5M2 (without scaling) is only supported on " + "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 " + "is instead supported for common inference criteria.") else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") @@ -531,9 +539,13 @@ class SchedulerConfig: iteration. max_model_len: Maximum length of a sequence (including prompt and generated text). + use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not. + num_lookahead_slots: The number of slots to allocate per sequence per + step, beyond the known token ids. This is used in speculative + decoding to store KV activations of tokens which may or may not be + accepted. delay_factor: Apply a delay (of delay factor multiplied by previous prompt latency) before scheduling next prompt. - use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not. enable_chunked_prefill: If True, prefill requests can be chunked based on the remaining max_num_batched_tokens. """ @@ -544,6 +556,7 @@ class SchedulerConfig: max_num_seqs: int, max_model_len: int, use_v2_block_manager: bool = False, + num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, ) -> None: @@ -555,13 +568,16 @@ class SchedulerConfig: self.max_num_batched_tokens = max(max_model_len, 2048) self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len - self.delay_factor = delay_factor self.use_v2_block_manager = use_v2_block_manager + self.num_lookahead_slots = num_lookahead_slots + self.delay_factor = delay_factor self.chunked_prefill_enabled = enable_chunked_prefill + self._verify_args() def _verify_args(self) -> None: - if self.max_num_batched_tokens < self.max_model_len: + if (self.max_num_batched_tokens < self.max_model_len + and not self.chunked_prefill_enabled): raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " f"smaller than max_model_len ({self.max_model_len}). " @@ -569,12 +585,19 @@ class SchedulerConfig: "max_num_batched_tokens and makes vLLM reject longer " "sequences. Please increase max_num_batched_tokens or " "decrease max_model_len.") + if self.max_num_batched_tokens < self.max_num_seqs: raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " "be greater than or equal to max_num_seqs " f"({self.max_num_seqs}).") + if self.num_lookahead_slots < 0: + raise ValueError( + "num_lookahead_slots " + f"({self.num_lookahead_slots}) must be greater than or " + "equal to 0.") + class DeviceConfig: @@ -585,6 +608,8 @@ class DeviceConfig: self.device_type = "neuron" elif is_tpu(): self.device_type = "tpu" + elif is_cpu(): + self.device_type = "cpu" else: # We don't call torch.cuda.is_available() here to # avoid initializing CUDA before workers are forked @@ -604,6 +629,159 @@ class DeviceConfig: self.device = torch.device(self.device_type) +class SpeculativeConfig: + """Configuration for speculative decoding. + + The configuration is currently specialized to draft-model speculative + decoding with top-1 proposals. + """ + + @staticmethod + def maybe_create_spec_config( + target_model_config: ModelConfig, + target_parallel_config: ParallelConfig, + target_dtype: str, + speculative_model: Optional[str], + num_speculative_tokens: Optional[int], + ) -> Optional["SpeculativeConfig"]: + """Create a SpeculativeConfig if possible, else return None. + + This function attempts to create a SpeculativeConfig object based on the + provided parameters. If the necessary conditions are met, it returns an + instance of SpeculativeConfig. Otherwise, it returns None. + + Args: + target_model_config (ModelConfig): The configuration of the target + model. + target_parallel_config (ParallelConfig): The parallel configuration + for the target model. + target_dtype (str): The data type used for the target model. + speculative_model (Optional[str]): The name of the speculative + model, if provided. + num_speculative_tokens (Optional[int]): The number of speculative + tokens, if provided. + + Returns: + Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if + the necessary conditions are met, else None. + """ + + if (speculative_model is None and num_speculative_tokens is None): + return None + + if speculative_model is not None and num_speculative_tokens is None: + raise ValueError( + "Expected both speculative_model and " + "num_speculative_tokens to be provided, but found " + f"{speculative_model=} and {num_speculative_tokens=}.") + + # TODO: The user should be able to specify revision/quantization/max + # model len for the draft model. It is not currently supported. + draft_revision = None + draft_code_revision = None + draft_quantization = None + draft_max_model_len = None + + draft_model_config = ModelConfig( + model=speculative_model, + tokenizer=target_model_config.tokenizer, + tokenizer_mode=target_model_config.tokenizer_mode, + trust_remote_code=target_model_config.trust_remote_code, + download_dir=target_model_config.download_dir, + load_format=target_model_config.load_format, + dtype=target_model_config.dtype, + seed=target_model_config.seed, + revision=draft_revision, + code_revision=draft_code_revision, + tokenizer_revision=target_model_config.tokenizer_revision, + max_model_len=draft_max_model_len, + quantization=draft_quantization, + enforce_eager=target_model_config.enforce_eager, + max_context_len_to_capture=target_model_config. + max_context_len_to_capture, + max_logprobs=target_model_config.max_logprobs, + ) + + draft_parallel_config = ( + SpeculativeConfig.create_draft_parallel_config( + target_parallel_config)) + + return SpeculativeConfig( + draft_model_config, + draft_parallel_config, + num_speculative_tokens, + ) + + @staticmethod + def create_draft_parallel_config( + target_parallel_config: ParallelConfig) -> ParallelConfig: + """Create a parallel config for use by the draft worker. + + This is mostly a copy of the target parallel config. In the future the + draft worker can have a different parallel strategy, e.g. TP=1. + """ + draft_parallel_config = ParallelConfig( + pipeline_parallel_size=target_parallel_config. + pipeline_parallel_size, + tensor_parallel_size=target_parallel_config.tensor_parallel_size, + worker_use_ray=target_parallel_config.worker_use_ray, + max_parallel_loading_workers=target_parallel_config. + max_parallel_loading_workers, + disable_custom_all_reduce=target_parallel_config. + disable_custom_all_reduce, + tokenizer_pool_config=target_parallel_config.tokenizer_pool_config, + ray_workers_use_nsight=target_parallel_config. + ray_workers_use_nsight, + placement_group=target_parallel_config.placement_group, + ) + + return draft_parallel_config + + def __init__( + self, + draft_model_config: ModelConfig, + draft_parallel_config: ParallelConfig, + num_speculative_tokens: int, + ): + """Create a SpeculativeConfig object. + + Args: + draft_model_config: ModelConfig for the draft model. + draft_parallel_config: ParallelConfig for the draft model. + num_speculative_tokens: The number of tokens to sample from the + draft model before scoring with the target model. + """ + self.draft_model_config = draft_model_config + self.draft_parallel_config = draft_parallel_config + self.num_speculative_tokens = num_speculative_tokens + + self._verify_args() + + def _verify_args(self) -> None: + if self.num_speculative_tokens <= 0: + raise ValueError("Expected num_speculative_tokens to be greater " + f"than zero ({self.num_speculative_tokens}).") + + if self.draft_model_config: + self.draft_model_config.verify_with_parallel_config( + self.draft_parallel_config) + + @property + def num_lookahead_slots(self) -> int: + """The number of additional slots the scheduler should allocate per + step, in addition to the slots allocated for each known token. + + This is equal to the number of speculative tokens, as each speculative + token must be scored. + """ + return self.num_speculative_tokens + + def __repr__(self) -> str: + draft_model = self.draft_model_config.model + num_spec_tokens = self.num_speculative_tokens + return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})" + + @dataclass class LoRAConfig: max_lora_rank: int @@ -825,3 +1003,36 @@ def _get_and_verify_max_len( "to incorrect model outputs or CUDA errors. Make sure the " "value is correct and within the model context size.") return int(max_model_len) + + +@dataclass(frozen=True) +class EngineConfig: + """Dataclass which contains all engine-related configuration. This + simplifies passing around the distinct configurations in the codebase. + """ + + model_config: ModelConfig + cache_config: CacheConfig + parallel_config: ParallelConfig + scheduler_config: SchedulerConfig + device_config: DeviceConfig + lora_config: Optional[LoRAConfig] + vision_language_config: Optional[VisionLanguageConfig] + speculative_config: Optional[SpeculativeConfig] + + def __post_init__(self): + """Verify configs are valid & consistent with each other. + """ + self.model_config.verify_with_parallel_config(self.parallel_config) + self.cache_config.verify_with_parallel_config(self.parallel_config) + + if self.lora_config: + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) + + def to_dict(self): + """Return the configs as a dictionary, for use in **kwargs. + """ + return dict( + (field.name, getattr(self, field.name)) for field in fields(self)) diff --git a/vllm/core/block/__init__.py b/vllm/core/block/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index 793c6698633af..ba061bbc4fbcb 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -85,7 +85,9 @@ class BlockTable: device=device) self._num_full_slots = len(token_ids) - def append_token_ids(self, token_ids: List[int]) -> None: + def append_token_ids(self, + token_ids: List[int], + num_lookahead_slots: int = 0) -> None: """Appends a sequence of token IDs to the existing blocks in the BlockTable. @@ -102,14 +104,13 @@ class BlockTable: token_ids (List[int]): The sequence of token IDs to be appended. """ assert self._is_allocated + assert token_ids, "can't append empty token ids" - self.ensure_num_empty_slots(num_empty_slots=len(token_ids)) + self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + + num_lookahead_slots) blocks = self._blocks[self._num_full_slots // self._block_size:] - first_chunk_size = self._block_size - (self._num_full_slots % - self._block_size) - token_blocks = [token_ids[:first_chunk_size]] + chunk_list( - token_ids[first_chunk_size:], self._block_size) + token_blocks = self._chunk_token_blocks_for_append(token_ids) for block, token_block in zip(blocks, token_blocks): block.append_token_ids(token_block) @@ -195,6 +196,25 @@ class BlockTable: assert self._is_allocated return [block.block_id for block in self._blocks] + def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]: + """Get the number of "unseen" tokens in the sequence. + + Unseen tokens are tokens in the sequence corresponding to this block + table, but are not yet appended to this block table. + + Args: + sequence_token_ids (List[int]): The list of token ids in the + sequence. + + Returns: + List[int]: The postfix of sequence_token_ids that has not yet been + appended to the block table. + """ + + # Since the block table is append-only, the unseen token ids are the + # ones after the appended ones. + return sequence_token_ids[self.num_full_slots:] + def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block], token_ids: List[int], device: Device) -> List[Block]: @@ -243,3 +263,29 @@ class BlockTable: int: The total number of tokens currently stored in the BlockTable. """ return self._num_full_slots + + def get_num_blocks_touched_by_append_slots( + self, token_ids: List[int], num_lookahead_slots: int) -> int: + """Determine how many blocks will be "touched" by appending the token + ids. + + This is required for the scheduler to determine whether a sequence can + continue generation, or if it must be preempted. + """ + + all_token_ids = token_ids + [-1] * num_lookahead_slots + token_blocks = self._chunk_token_blocks_for_append(all_token_ids) + return len(token_blocks) + + def _chunk_token_blocks_for_append( + self, token_ids: List[int]) -> List[List[int]]: + """Split the token ids into block-sized chunks so they can be easily + appended to blocks. The first such "token block" may have less token ids + than the block size, since the last allocated block may be partially + full. + """ + first_chunk_size = self._block_size - (self._num_full_slots % + self._block_size) + token_blocks = [token_ids[:first_chunk_size]] + chunk_list( + token_ids[first_chunk_size:], self._block_size) + return token_blocks diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 160a86556f031..e7e3b4dc1e9b4 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from itertools import count, takewhile from os.path import commonprefix -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set from vllm.block import BlockTable, PhysicalTokenBlock from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor @@ -292,7 +292,12 @@ class BlockSpaceManagerV1(BlockSpaceManager): for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() - def can_append_slot(self, seq_group: SequenceGroup) -> bool: + def can_append_slots(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> bool: + assert (num_lookahead_slots == 0 + ), "lookahead allocation not supported in BlockSpaceManagerV1" + # Simple heuristic: If there is at least one free block # for each sequence, we can append. num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() @@ -323,7 +328,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): self, seq: Sequence, ) -> bool: - token_ids_len = len(seq.data.get_token_ids()) + token_ids_len = seq.data.get_len() return token_ids_len > 0 and token_ids_len % seq.block_size == 0 def _maybe_promote_last_block( @@ -364,10 +369,11 @@ class BlockSpaceManagerV1(BlockSpaceManager): assert new_block.ref_count == 1 return new_block - def append_slot( + def append_slots( self, seq: Sequence, - ) -> Optional[Tuple[int, int]]: + num_lookahead_slots: int = 0, + ) -> Dict[int, List[int]]: """Allocate a physical slot for a new token.""" logical_blocks = seq.logical_token_blocks block_table = self.block_tables[seq.seq_id] @@ -386,7 +392,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): # Allocate a new physical block. new_block = self._allocate_last_physical_block(seq) block_table.append(new_block) - return None + return {} # We want to append the token to the last physical block. last_block = block_table[-1] @@ -399,7 +405,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): maybe_new_block = self._maybe_promote_last_block( seq, last_block) block_table[-1] = maybe_new_block - return None + return {} else: # The last block is shared with other sequences. # Copy on Write: Allocate a new block and copy the tokens. @@ -407,7 +413,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): block_table[-1] = new_block self.gpu_allocator.free(last_block) - return last_block.block_number, new_block.block_number + return {last_block.block_number: [new_block.block_number]} def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: # NOTE: fork does not allocate a new physical block. @@ -433,7 +439,11 @@ class BlockSpaceManagerV1(BlockSpaceManager): blocks.update(self.block_tables[seq.seq_id]) return list(blocks) - def can_swap_in(self, seq_group: SequenceGroup) -> bool: + def can_swap_in(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> bool: + assert (num_lookahead_slots == 0 + ), "BlockSpaceManagerV1 does not support lookahead allocation" blocks = self._get_physical_blocks(seq_group) num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) num_free_blocks = self.gpu_allocator.get_num_free_blocks() @@ -443,7 +453,12 @@ class BlockSpaceManagerV1(BlockSpaceManager): num_required_blocks = len(blocks) + num_swapped_seqs return num_free_blocks - num_required_blocks >= self.watermark_blocks - def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: + def swap_in(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> Dict[int, int]: + assert (num_lookahead_slots == 0 + ), "BlockSpaceManagerV1 does not support lookahead allocation" + # CPU block -> GPU block. mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 37c70073b663b..813e71ad883b2 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -1,5 +1,5 @@ """A block manager that manages token blocks.""" -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator @@ -21,6 +21,24 @@ class BlockSpaceManagerV2(BlockSpaceManager): sliding-window are not feature complete. This class implements the design described in https://github.com/vllm-project/vllm/pull/3492. + Lookahead slots + The block manager has the notion of a "lookahead slot". These are slots + in the KV cache that are allocated for a sequence. Unlike the other + allocated slots, the content of these slots is undefined -- the worker + may use the memory allocations in any way. + + In practice, a worker could use these lookahead slots to run multiple + forward passes for a single scheduler invocation. Each successive + forward pass would write KV activations to the corresponding lookahead + slot. This allows low inter-token latency use-cases, where the overhead + of continuous batching scheduling is amortized over >1 generated tokens. + + Speculative decoding uses lookahead slots to store KV activations of + proposal tokens. + + See https://github.com/vllm-project/vllm/pull/3250 for more information + on lookahead scheduling. + Args: block_size (int): The size of each memory block. num_gpu_blocks (int): The number of memory blocks allocated on GPU. @@ -116,35 +134,51 @@ class BlockSpaceManagerV2(BlockSpaceManager): for seq in waiting_seqs[1:]: self.block_tables[seq.seq_id] = block_table.fork() - def can_append_slot(self, seq_group: SequenceGroup) -> bool: - # Simple heuristic: If there is at least one free block - # for each sequence, we can append. + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + """Determine if there is enough space in the GPU KV cache to continue + generation of the specified sequence group. + + We use a worst-case heuristic: assume each touched block will require a + new allocation (either via CoW or new block). We can append slots if the + number of touched blocks is less than the number of free blocks. + + "Lookahead slots" are slots that are allocated in addition to the slots + for known tokens. The contents of the lookahead slots are not defined. + This is used by speculative decoding when speculating future tokens. + """ + + num_touched_blocks = 0 + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + block_table = self.block_tables[seq.seq_id] + + num_touched_blocks += ( + block_table.get_num_blocks_touched_by_append_slots( + token_ids=block_table.get_unseen_token_ids( + seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + )) + num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( Device.GPU) - num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) - return num_seqs <= num_free_gpu_blocks + return num_touched_blocks <= num_free_gpu_blocks - def append_slot( + def append_slots( self, seq: Sequence, - ) -> Optional[Tuple[int, int]]: + num_lookahead_slots: int, + ) -> Dict[int, List[int]]: block_table = self.block_tables[seq.seq_id] - # Get unseen token ids. - num_full_slots = block_table.num_full_slots - unseen_token_ids = seq.get_token_ids()[num_full_slots:] - assert unseen_token_ids + block_table.append_token_ids( + token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + ) - block_table.append_token_ids(unseen_token_ids) - - # Return any copy-on-writes. - _ = self.block_allocator.clear_copy_on_writes() - - # TODO extend append_slot interface to append_slots - # @cadedaniel will do in https://github.com/vllm-project/vllm/pull/3250 - - return None + # Return any new copy-on-writes. + new_cows = self.block_allocator.clear_copy_on_writes() + return new_cows def free(self, seq: Sequence) -> None: if seq.seq_id not in self.block_tables: @@ -191,10 +225,12 @@ class BlockSpaceManagerV2(BlockSpaceManager): src_block_table = self.block_tables[parent_seq.seq_id] self.block_tables[child_seq.seq_id] = src_block_table.fork() - def can_swap_in(self, seq_group: SequenceGroup) -> bool: + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: return False - def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: + def swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> Dict[int, int]: raise NotImplementedError def can_swap_out(self, seq_group: SequenceGroup) -> bool: diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 48524de0df8ea..711536bcc97be 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -1,6 +1,6 @@ import enum from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple +from typing import Dict, List from vllm.sequence import Sequence, SequenceGroup @@ -44,14 +44,16 @@ class BlockSpaceManager(ABC): pass @abstractmethod - def can_append_slot(self, seq_group: SequenceGroup) -> bool: + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: pass @abstractmethod - def append_slot( + def append_slots( self, seq: Sequence, - ) -> Optional[Tuple[int, int]]: + num_lookahead_slots: int, + ) -> Dict[int, List[int]]: pass @abstractmethod @@ -59,11 +61,13 @@ class BlockSpaceManager(ABC): pass @abstractmethod - def can_swap_in(self, seq_group: SequenceGroup) -> bool: + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: pass @abstractmethod - def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: + def swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> Dict[int, int]: pass @abstractmethod diff --git a/vllm/core/policy.py b/vllm/core/policy.py index 2e9ebbda54412..a4463ac0f340e 100644 --- a/vllm/core/policy.py +++ b/vllm/core/policy.py @@ -38,9 +38,7 @@ class FCFS(Policy): class PolicyFactory: - _POLICY_REGISTRY = { - 'fcfs': FCFS, - } + _POLICY_REGISTRY = {'fcfs': FCFS} @classmethod def get_policy(cls, policy_name: str, **kwargs) -> Policy: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 04e8056aab544..0ae53f9374960 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,16 +1,17 @@ import enum import time from collections import deque -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager -from vllm.core.policy import PolicyFactory +from vllm.core.policy import Policy, PolicyFactory from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) +from vllm.utils import merge_dicts logger = init_logger(__name__) @@ -28,9 +29,67 @@ class PreemptionMode(enum.Enum): RECOMPUTE = enum.auto() -# seq_group: SequenceGroup to schedule. -# token_chunk_size: The number of prefill tokens to be processed in the next -# step. +@dataclass +class SchedulingBudget: + """The available slots for scheduling. + + TODO(sang): Right now, the budget is request_id-aware meaning it can ignore + budget update from the same request_id. It is because in normal scheduling + path, we update RUNNING num_seqs ahead of time, meaning it could be + updated more than once when scheduling RUNNING requests. Since this won't + happen if we only have chunked prefill scheduling, we can remove this + feature from the API when chunked prefill is enabled by default. + """ + token_budget: int + max_num_seqs: int + _requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set) + _requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set) + _num_batched_tokens: int = 0 + _num_curr_seqs: int = 0 + + def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): + assert num_new_tokens != 0 + assert num_new_seqs != 0 + return (self.num_batched_tokens + num_new_tokens <= self.token_budget + and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs) + + def remaining_token_budget(self): + return self.token_budget - self.num_batched_tokens + + def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int): + if req_id in self._requeset_ids_num_batched_tokens: + return + + self._requeset_ids_num_batched_tokens.add(req_id) + self._num_batched_tokens += num_batched_tokens + + def subtract_num_batched_tokens(self, req_id: str, + num_batched_tokens: int): + if req_id in self._requeset_ids_num_batched_tokens: + self._requeset_ids_num_batched_tokens.remove(req_id) + self._num_batched_tokens -= num_batched_tokens + + def add_num_seqs(self, req_id: str, num_curr_seqs: int): + if req_id in self._requeset_ids_num_curr_seqs: + return + + self._requeset_ids_num_curr_seqs.add(req_id) + self._num_curr_seqs += num_curr_seqs + + def subtract_num_seqs(self, req_id: str, num_curr_seqs: int): + if req_id in self._requeset_ids_num_curr_seqs: + self._requeset_ids_num_curr_seqs.remove(req_id) + self._num_curr_seqs -= num_curr_seqs + + @property + def num_batched_tokens(self): + return self._num_batched_tokens + + @property + def num_curr_seqs(self): + return self._num_curr_seqs + + @dataclass class ScheduledSequenceGroup: # A sequence group that's scheduled. @@ -41,51 +100,29 @@ class ScheduledSequenceGroup: token_chunk_size: int +@dataclass class SchedulerOutputs: + """The scheduling decision made from a scheduler.""" + # Scheduled sequence groups. + scheduled_seq_groups: Iterable[ScheduledSequenceGroup] + # Number of prefill groups scheduled. + num_prefill_groups: int + # Total number of batched tokens. + num_batched_tokens: int + # Blocks to swap in. Dict of CPU -> GPU block number. + blocks_to_swap_in: Dict[int, int] + # Blocks to swap out. Dict of GPU -> CPU block number. + blocks_to_swap_out: Dict[int, int] + # Blocks to copy. Source to a list of dest blocks. + blocks_to_copy: Dict[int, List[int]] + # Sequence groups that are going to be ignored. + ignored_seq_groups: List[SequenceGroup] + # The number of slots for lookahead decoding. + num_lookahead_slots: int - def __init__( - self, - scheduled_seq_groups: Iterable[ScheduledSequenceGroup], - prompt_run: bool, - num_batched_tokens: int, - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - ignored_seq_groups: List[SequenceGroup], - ) -> None: - """A list of sequence groups to be scheduled as a single batch. - - Args: - scheduled_seq_groups: A tuple of scheduled sequence group and its - token chunk size. - prompt_run: True if all sequence groups are in prefill phase. - If False, all sequence groups are in decoding phase. - num_batched_tokens: Total number of batched tokens. - blocks_to_swap_in: Blocks to swap in. Dict of CPU -> GPU block - number. - blocks_to_swap_out: Blocks to swap out. Dict of GPU -> CPU block - number. - blocks_to_copy: Blocks to copy. Source to a list of dest blocks. - ignored_seq_groups: Sequence groups that are going to be ignored. - """ - # A tuple of scheduled sequence group and its chunk size. - self.scheduled_seq_groups: ScheduledSequenceGroup = scheduled_seq_groups - # True if all sequence groups are in prefill phase. If False, all - # sequence groups are in decoding phase. - self.prompt_run: bool = prompt_run - # Total number of batched tokens. - self.num_batched_tokens: int = num_batched_tokens - # Blocks to swap in. Dict of CPU -> GPU block number. - self.blocks_to_swap_in: Dict[int, int] = blocks_to_swap_in - # Blocks to swap out. Dict of GPU -> CPU block number. - self.blocks_to_swap_out: Dict[int, int] = blocks_to_swap_out - # Blocks to copy. Source to a list of dest blocks. - self.blocks_to_copy: Dict[int, List[int]] = blocks_to_copy - # Sequence groups that are going to be ignored. - self.ignored_seq_groups: List[SequenceGroup] = ignored_seq_groups - + def __post_init__(self): # Swap in and swap out should never happen at the same time. - assert not (blocks_to_swap_in and blocks_to_swap_out) + assert not (self.blocks_to_swap_in and self.blocks_to_swap_out) self.num_loras: int = len(self.lora_requests) if self.num_loras > 0: @@ -106,6 +143,94 @@ class SchedulerOutputs: return {g.seq_group.lora_request for g in self.scheduled_seq_groups} +@dataclass +class SchedulerRunningOutputs: + """The requests that are scheduled from a running queue. + + Could contain prefill (prefill that's chunked) or decodes. If there's not + enough memory, it can be preempted (for recompute) or swapped out. + """ + # Selected sequences that are running and in a decoding phase. + decode_seq_groups: List[SequenceGroup] + # Selected sequences that are running and in a prefill phase. + # I.e., it means the prefill has been chunked. + prefill_seq_groups: List[SequenceGroup] + # The preempted sequences. + preempted: List[SequenceGroup] + # Sequences that are swapped out. + swapped_out: List[SequenceGroup] + # The blocks to swap out. + blocks_to_swap_out: Dict[int, int] + # The blocks to copy. + blocks_to_copy: Dict[int, List[int]] + # The number of slots for lookahead decoding. + num_lookahead_slots: int + + @classmethod + def create_empty(cls) -> "SchedulerRunningOutputs": + return SchedulerRunningOutputs( + decode_seq_groups=[], + prefill_seq_groups=[], + preempted=[], + swapped_out=[], + blocks_to_swap_out={}, + blocks_to_copy={}, + num_lookahead_slots=0, + ) + + +@dataclass +class SchedulerSwappedInOutputs: + """The requests that are scheduled from a swap queue. + + Could contain prefill (prefill that's chunked) or decodes. + """ + # Selected sequences that are going to be swapped in and is in a + # decoding phase. + decode_seq_groups: List[SequenceGroup] + # Selected sequences that are going to be swapped in and in a prefill + # phase. I.e., it means the prefill has been chunked. + prefill_seq_groups: List[SequenceGroup] + # The blocks to swap in. + blocks_to_swap_in: Dict[int, int] + # The blocks to copy. + blocks_to_copy: Dict[int, List[int]] + # The number of slots for lookahead decoding. + num_lookahead_slots: int + + @classmethod + def create_empty(cls) -> "SchedulerSwappedInOutputs": + return SchedulerSwappedInOutputs( + decode_seq_groups=[], + prefill_seq_groups=[], + blocks_to_swap_in={}, + blocks_to_copy={}, + num_lookahead_slots=0, + ) + + +@dataclass +class SchedulerPrefillOutputs: + """The requests that are scheduled from a waiting queue. + + Could contain a fresh prefill requests or preempted requests that need + to be recomputed from scratch. + """ + # Selected sequences for prefill. + seq_groups: List[SequenceGroup] + # Ignored sequence groups. + ignored_seq_groups: List[SequenceGroup] + num_lookahead_slots: int + + @classmethod + def create_empty(cls) -> "SchedulerPrefillOutputs": + return SchedulerPrefillOutputs( + seq_groups=[], + ignored_seq_groups=[], + num_lookahead_slots=0, + ) + + class Scheduler: def __init__( @@ -121,11 +246,12 @@ class Scheduler: # LoRAs. This should be improved in the future. self.lora_config = lora_config - self.prompt_limit = min(self.scheduler_config.max_model_len, - self.scheduler_config.max_num_batched_tokens) - - # Instantiate the scheduling policy. - self.policy = PolicyFactory.get_policy(policy_name="fcfs") + if self.scheduler_config.chunked_prefill_enabled: + self.prompt_limit = self.scheduler_config.max_model_len + else: + self.prompt_limit = min( + self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens) BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( version="v2" if self.scheduler_config. @@ -140,10 +266,13 @@ class Scheduler: enable_caching=self.cache_config.enable_prefix_caching) # Sequence groups in the WAITING state. + # Contain new prefill or preempted requests. self.waiting: Deque[SequenceGroup] = deque() # Sequence groups in the RUNNING state. + # Contain decode requests. self.running: Deque[SequenceGroup] = deque() # Sequence groups in the SWAPPED state. + # Contain decode requests that are swapped out. self.swapped: Deque[SequenceGroup] = deque() # Time at previous scheduling step @@ -157,8 +286,14 @@ class Scheduler: def lora_enabled(self) -> bool: return bool(self.lora_config) + @property + def num_decoding_tokens_per_seq(self) -> int: + """The number of new tokens.""" + return 1 + def add_seq_group(self, seq_group: SequenceGroup) -> None: # Add sequence groups to the waiting queue. + logger.debug(f"add_seq_group {seq_group.request_id}") self.waiting.append(seq_group) def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: @@ -203,210 +338,542 @@ class Scheduler: def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) - def _schedule(self) -> SchedulerOutputs: + def _schedule_running( + self, + running_queue: deque, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + policy: Policy, + enable_chunking: bool = False, + ) -> Tuple[deque, SchedulerRunningOutputs]: + """Schedule sequence groups that are running. + + Running queue should include decode and chunked prefill requests. + + Args: + running_queue: The queue that contains running requests (i.e., + decodes). The given arguments are NOT in-place modified. + budget: The scheduling budget. The argument is in-place updated + when any decodes are preempted. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any decodes are preempted. + policy: The sorting policy to sort running_queue. + enable_chunking: If True, seq group can be chunked and only a + chunked number of tokens are scheduled if + `budget.num_batched_tokens` has not enough capacity to schedule + all tokens. + + Returns: + A tuple of remaining running queue (should be always 0) after + scheduling and SchedulerRunningOutputs. + """ # Blocks that need to be swapped or copied before model execution. - blocks_to_swap_in: Dict[int, int] = {} blocks_to_swap_out: Dict[int, int] = {} blocks_to_copy: Dict[int, List[int]] = {} - # Fix the current time. - now = time.time() - - # Join waiting sequences if possible. - if not self.swapped: - ignored_seq_groups: List[SequenceGroup] = [] - scheduled: List[SequenceGroup] = [] - # The total number of sequences on the fly, including the - # requests in the generation phase. - num_curr_seqs = sum(seq_group.get_max_num_running_seqs() - for seq_group in self.running) - curr_loras = set( - seq_group.lora_int_id - for seq_group in self.running) if self.lora_enabled else None - - # Optimization: We do not sort the waiting queue since the preempted - # sequence groups are added to the front and the new sequence groups - # are added to the back. - leftover_waiting_sequences = deque() - num_batched_tokens = 0 - while self._passed_delay(now) and self.waiting: - seq_group = self.waiting[0] - waiting_seqs = seq_group.get_seqs( - status=SequenceStatus.WAITING) - assert len(waiting_seqs) == 1, ( - "Waiting sequence group should have only one prompt " - "sequence.") - # get_len includes output tokens if the request has been - # preempted. - num_prefill_tokens = waiting_seqs[0].get_len() - if num_prefill_tokens > self.prompt_limit: - logger.warning( - f"Input prompt ({num_prefill_tokens} tokens) is too " - f"long and exceeds limit of {self.prompt_limit}") - for seq in waiting_seqs: - seq.status = SequenceStatus.FINISHED_IGNORED - ignored_seq_groups.append(seq_group) - self.waiting.popleft() - continue - - # If the sequence group cannot be allocated, stop. - can_allocate = self.block_manager.can_allocate(seq_group) - if can_allocate == AllocStatus.LATER: - break - elif can_allocate == AllocStatus.NEVER: - logger.warning( - f"Input prompt ({num_prefill_tokens} tokens) is too " - f"long and exceeds the capacity of block_manager") - for seq in waiting_seqs: - seq.status = SequenceStatus.FINISHED_IGNORED - ignored_seq_groups.append(seq_group) - self.waiting.popleft() - continue - - lora_int_id = 0 - if self.lora_enabled: - lora_int_id = seq_group.lora_int_id - if (lora_int_id > 0 and lora_int_id not in curr_loras - and len(curr_loras) >= self.lora_config.max_loras): - # We don't have a space for another LoRA, so - # we ignore this request for now. - leftover_waiting_sequences.appendleft(seq_group) - self.waiting.popleft() - continue - - # If the number of batched tokens exceeds the limit, stop. - num_batched_tokens += num_prefill_tokens - if (num_batched_tokens > - self.scheduler_config.max_num_batched_tokens): - break - - # The total number of sequences in the RUNNING state should not - # exceed the maximum number of sequences. - num_new_seqs = seq_group.get_max_num_running_seqs() - if (num_curr_seqs + num_new_seqs > - self.scheduler_config.max_num_seqs): - break - - if lora_int_id > 0: - curr_loras.add(lora_int_id) - self.waiting.popleft() - self._allocate(seq_group) - self.running.append(seq_group) - num_curr_seqs += num_new_seqs - scheduled.append( - ScheduledSequenceGroup( - seq_group=seq_group, - token_chunk_size=num_prefill_tokens)) - self.waiting.extendleft(leftover_waiting_sequences) - - if scheduled or ignored_seq_groups: - self.prev_prompt = True - scheduler_outputs = SchedulerOutputs( - scheduled_seq_groups=scheduled, - prompt_run=True, - num_batched_tokens=num_batched_tokens, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ignored_seq_groups=ignored_seq_groups, - ) - return scheduler_outputs + decode_seq_groups: List[ScheduledSequenceGroup] = [] + prefill_seq_groups: List[ScheduledSequenceGroup] = [] + preempted: List[SequenceGroup] = [] + swapped_out: List[SequenceGroup] = [] # NOTE(woosuk): Preemption happens only when there is no available slot # to keep all the sequence groups in the RUNNING state. # In this case, the policy is responsible for deciding which sequence # groups to preempt. - self.running = self.policy.sort_by_priority(now, self.running) + now = time.time() + running_queue = policy.sort_by_priority(now, running_queue) - # Reserve new token slots for the running sequence groups. - running: Deque[SequenceGroup] = deque() - preempted: List[SequenceGroup] = [] - while self.running: - seq_group = self.running.popleft() - while not self.block_manager.can_append_slot(seq_group): - if self.running: + while running_queue: + seq_group = running_queue[0] + num_running_tokens = self._get_num_new_tokens( + seq_group, SequenceStatus.RUNNING, enable_chunking, budget) + + # We can have up to 1 running prefill at any given time in running + # queue, which means we can guarantee chunk size is at least 1. + assert num_running_tokens != 0 + num_running_seqs = seq_group.get_max_num_running_seqs() + + running_queue.popleft() + while not self._can_append_slots(seq_group): + budget.subtract_num_batched_tokens(seq_group.request_id, + num_running_tokens) + budget.subtract_num_seqs(seq_group.request_id, + num_running_seqs) + if curr_loras is not None and seq_group.lora_int_id > 0: + curr_loras.pop(seq_group.lora_int_id) + + if running_queue: # Preempt the lowest-priority sequence groups. - victim_seq_group = self.running.pop() - self._preempt(victim_seq_group, blocks_to_swap_out) - preempted.append(victim_seq_group) + victim_seq_group = running_queue.pop() + preempted_mode = self._preempt(victim_seq_group, + blocks_to_swap_out) + if preempted_mode == PreemptionMode.RECOMPUTE: + preempted.append(victim_seq_group) + else: + swapped_out.append(victim_seq_group) else: # No other sequence groups can be preempted. # Preempt the current sequence group. - self._preempt(seq_group, blocks_to_swap_out) - preempted.append(seq_group) + preempted_mode = self._preempt(seq_group, + blocks_to_swap_out) + if preempted_mode == PreemptionMode.RECOMPUTE: + preempted.append(seq_group) + else: + swapped_out.append(seq_group) break else: - # Append new slots to the sequence group. - self._append_slot(seq_group, blocks_to_copy) - running.append(seq_group) - self.running = running + logger.debug(f"append slot for {seq_group}") + self._append_slots(seq_group, blocks_to_copy) + is_prefill = seq_group.is_prefill() + if is_prefill: + prefill_seq_groups.append( + ScheduledSequenceGroup( + seq_group=seq_group, + token_chunk_size=num_running_tokens)) + else: + decode_seq_groups.append( + ScheduledSequenceGroup(seq_group=seq_group, + token_chunk_size=1)) + budget.add_num_batched_tokens(seq_group.request_id, + num_running_tokens) + budget.add_num_seqs(seq_group.request_id, num_running_seqs) + if curr_loras is not None and seq_group.lora_int_id > 0: + curr_loras.add(seq_group.lora_int_id) - # Swap in the sequence groups in the SWAPPED state if possible. - self.swapped = self.policy.sort_by_priority(now, self.swapped) - if not preempted: - num_curr_seqs = sum(seq_group.get_max_num_running_seqs() - for seq_group in self.running) - curr_loras = set( - seq_group.lora_int_id - for seq_group in self.running) if self.lora_enabled else None + # Make sure all queues are updated. + assert len(running_queue) == 0 - leftover_swapped = deque() - - while self.swapped: - seq_group = self.swapped[0] - lora_int_id = 0 - if self.lora_enabled: - lora_int_id = seq_group.lora_int_id - if (lora_int_id > 0 and lora_int_id not in curr_loras - and len(curr_loras) >= self.lora_config.max_loras): - # We don't have a space for another LoRA, so - # we ignore this request for now. - leftover_swapped.appendleft(seq_group) - self.swapped.popleft() - continue - - # If the sequence group cannot be swapped in, stop. - if not self.block_manager.can_swap_in(seq_group): - break - - # The total number of sequences in the RUNNING state should not - # exceed the maximum number of sequences. - num_new_seqs = seq_group.get_max_num_running_seqs() - if (num_curr_seqs + num_new_seqs > - self.scheduler_config.max_num_seqs): - break - - if lora_int_id > 0: - curr_loras.add(lora_int_id) - self.swapped.popleft() - self._swap_in(seq_group, blocks_to_swap_in) - self._append_slot(seq_group, blocks_to_copy) - num_curr_seqs += num_new_seqs - self.running.append(seq_group) - - self.swapped.extendleft(leftover_swapped) - - # Each sequence in the generation phase only takes one token slot. - # Therefore, the number of batched tokens is equal to the number of - # sequences in the RUNNING state. - num_batched_tokens = sum( - seq_group.num_seqs(status=SequenceStatus.RUNNING) - for seq_group in self.running) - - scheduler_outputs = SchedulerOutputs( - scheduled_seq_groups=[ - ScheduledSequenceGroup(seq_group=running_group, - token_chunk_size=1) - for running_group in self.running - ], - prompt_run=False, - num_batched_tokens=num_batched_tokens, - blocks_to_swap_in=blocks_to_swap_in, + return running_queue, SchedulerRunningOutputs( + decode_seq_groups=decode_seq_groups, + prefill_seq_groups=prefill_seq_groups, + preempted=preempted, + swapped_out=swapped_out, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, - ignored_seq_groups=[], + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=False)) + + def _schedule_swapped( + self, + swapped_queue: deque, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + policy: Policy, + enable_chunking: bool = False, + ) -> Tuple[deque, SchedulerSwappedInOutputs]: + """Schedule sequence groups that are swapped out. + + It schedules swapped requests as long as it fits `budget` and + curr_loras <= max_lora from the scheduling config. The input arguments + `budget` and `curr_loras` are updated based on scheduled seq_groups. + + Args: + swapped_queue: The queue that contains swapped out requests. + The given arguments are NOT in-place modified. + budget: The scheduling budget. The argument is in-place updated + when any requests are swapped in. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any requests are swapped in. + policy: The sorting policy to sort swapped_queue. + enable_chunking: If True, seq group can be chunked and only a + chunked number of tokens are scheduled if + `budget.num_batched_tokens` has not enough capacity to schedule + all tokens. + + Returns: + A tuple of remaining swapped_queue after scheduling and + SchedulerSwappedInOutputs. + """ + # Blocks that need to be swapped or copied before model execution. + blocks_to_swap_in: Dict[int, int] = {} + blocks_to_copy: Dict[int, List[int]] = {} + decode_seq_groups: List[ScheduledSequenceGroup] = [] + prefill_seq_groups: List[ScheduledSequenceGroup] = [] + now = time.time() + swapped_queue = policy.sort_by_priority(now, swapped_queue) + + leftover_swapped = deque() + while swapped_queue: + seq_group = swapped_queue[0] + + # If the sequence group cannot be swapped in, stop. + if not self.block_manager.can_swap_in(seq_group): + break + + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + if (lora_int_id > 0 and lora_int_id not in curr_loras + and len(curr_loras) >= self.lora_config.max_loras): + # We don't have a space for another LoRA, so + # we ignore this request for now. + leftover_swapped.appendleft(seq_group) + swapped_queue.popleft() + continue + + # The total number of sequences in the RUNNING state should not + # exceed the maximum number of sequences. + num_new_seqs = seq_group.get_max_num_running_seqs() + num_new_tokens = self._get_num_new_tokens(seq_group, + SequenceStatus.SWAPPED, + enable_chunking, budget) + + if (num_new_tokens == 0 + or not budget.can_schedule(num_new_tokens=num_new_tokens, + num_new_seqs=num_new_seqs)): + break + + if lora_int_id > 0 and curr_loras is not None: + curr_loras.add(lora_int_id) + swapped_queue.popleft() + self._swap_in(seq_group, blocks_to_swap_in) + self._append_slots(seq_group, blocks_to_copy) + is_prefill = seq_group.is_prefill() + if is_prefill: + prefill_seq_groups.append( + ScheduledSequenceGroup(seq_group, + token_chunk_size=num_new_tokens)) + else: + assert num_new_tokens == 1 + decode_seq_groups.append( + ScheduledSequenceGroup(seq_group, token_chunk_size=1)) + budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) + budget.add_num_seqs(seq_group.request_id, num_new_seqs) + + swapped_queue.extendleft(leftover_swapped) + + return swapped_queue, SchedulerSwappedInOutputs( + decode_seq_groups=decode_seq_groups, + prefill_seq_groups=prefill_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_copy=blocks_to_copy, + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=False)) + + def _schedule_prefills( + self, + waiting_queue: deque, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + enable_chunking: bool = False, + ) -> Tuple[deque, SchedulerPrefillOutputs]: + """Schedule sequence groups that are in prefill stage. + + Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE + as a new prefill (that starts from beginning -> most recently generated + tokens). + + It schedules waiting requests as long as it fits `budget` and + curr_loras <= max_lora from the scheduling config. The input arguments + `budget` and `curr_loras` are updated based on scheduled seq_groups. + + Args: + waiting_queue: The queue that contains prefill requests. + The given arguments are NOT in-place modified. + budget: The scheduling budget. The argument is in-place updated + when any requests are scheduled. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any requests are scheduled. + enable_chunking: If True, seq group can be chunked and only a + chunked number of tokens are scheduled if + `budget.num_batched_tokens` has not enough capacity to schedule + all tokens. + + Returns: + A tuple of remaining waiting_queue after scheduling and + SchedulerSwappedInOutputs. + """ + ignored_seq_groups: List[SequenceGroup] = [] + seq_groups: List[SequenceGroup] = [] + # We don't sort waiting queue because we assume it is sorted. + # Copy the queue so that the input queue is not modified. + waiting_queue = deque([s for s in waiting_queue]) + + leftover_waiting_sequences = deque() + while self._passed_delay(time.time()) and waiting_queue: + seq_group = waiting_queue[0] + + waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + assert len(waiting_seqs) == 1, ( + "Waiting sequence group should have only one prompt " + "sequence.") + num_new_tokens = self._get_num_new_tokens(seq_group, + SequenceStatus.WAITING, + enable_chunking, budget) + if not enable_chunking: + num_prompt_tokens = waiting_seqs[0].get_len() + assert num_new_tokens == num_prompt_tokens + + if num_new_tokens > self.prompt_limit: + logger.warning( + f"Input prompt ({num_new_tokens} tokens) is too long" + f" and exceeds limit of {self.prompt_limit}") + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + waiting_queue.popleft() + continue + + # If the sequence group cannot be allocated, stop. + can_allocate = self.block_manager.can_allocate(seq_group) + if can_allocate == AllocStatus.LATER: + break + elif can_allocate == AllocStatus.NEVER: + logger.warning( + f"Input prompt ({num_new_tokens} tokens) is too long" + f" and exceeds the capacity of block_manager") + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + waiting_queue.popleft() + continue + + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + if (self.lora_enabled and lora_int_id > 0 + and lora_int_id not in curr_loras + and len(curr_loras) >= self.lora_config.max_loras): + # We don't have a space for another LoRA, so + # we ignore this request for now. + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue + + num_new_seqs = seq_group.get_max_num_running_seqs() + if (num_new_tokens == 0 + or not budget.can_schedule(num_new_tokens=num_new_tokens, + num_new_seqs=num_new_seqs)): + break + + # Can schedule this request. + if curr_loras is not None and lora_int_id > 0: + curr_loras.add(lora_int_id) + waiting_queue.popleft() + self._allocate_and_set_running(seq_group, num_new_tokens) + seq_groups.append( + ScheduledSequenceGroup(seq_group=seq_group, + token_chunk_size=num_new_tokens)) + budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) + budget.add_num_seqs(seq_group.request_id, num_new_seqs) + + # Queue requests that couldn't be scheduled. + waiting_queue.extendleft(leftover_waiting_sequences) + if len(seq_groups) > 0: + self.prev_prompt = True + + return waiting_queue, SchedulerPrefillOutputs( + seq_groups=seq_groups, + ignored_seq_groups=ignored_seq_groups, + num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True)) + + def _schedule_default(self) -> SchedulerOutputs: + """Schedule queued requests. + + The current policy is designed to opimimize the throughput. First, + it batches as many prefill requests as possible. And it schedules + decodes. If there's a pressure on GPU memory, decode requests can + be swapped or preempted. + """ + # Include running requests to the budget. + budget = SchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + ) + # Make sure we include num running seqs before scheduling prefill, + # so that we don't schedule beyond max_num_seqs for prefill. + for seq_group in self.running: + budget.add_num_seqs(seq_group.request_id, + seq_group.get_max_num_running_seqs()) + curr_loras = set( + seq_group.lora_int_id + for seq_group in self.running) if self.lora_enabled else None + + remaining_waiting, prefills = (self.waiting, + SchedulerPrefillOutputs.create_empty()) + remaining_running, running_scheduled = ( + self.running, SchedulerRunningOutputs.create_empty()) + remaining_swapped, swapped_in = ( + self.swapped, SchedulerSwappedInOutputs.create_empty()) + + # If any requests are swapped, prioritized swapped requests. + if not self.swapped: + remaining_waiting, prefills = self._schedule_prefills( + self.waiting, budget, curr_loras, enable_chunking=False) + + fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs") + # Don't schedule decodes if prefills are scheduled. + # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running + # only contains decode requests, not chunked prefills. + if len(prefills.seq_groups) == 0: + remaining_running, running_scheduled = self._schedule_running( + self.running, + budget, + curr_loras, + fcfs_policy, + enable_chunking=False) + + # If any sequence group is preempted, do not swap in any sequence + # group. because it means there's no slot for new running requests. + if len(running_scheduled.preempted) + len( + running_scheduled.swapped_out) == 0: + remaining_swapped, swapped_in = self._schedule_swapped( + self.swapped, budget, curr_loras, fcfs_policy) + + assert (budget.num_batched_tokens <= + self.scheduler_config.max_num_batched_tokens) + assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs + + # Update waiting requests. + self.waiting = remaining_waiting + self.waiting.extendleft(running_scheduled.preempted) + # Update new running requests. + self.running = remaining_running + self.running.extend([s.seq_group for s in prefills.seq_groups]) + self.running.extend( + [s.seq_group for s in running_scheduled.decode_seq_groups]) + self.running.extend( + [s.seq_group for s in swapped_in.decode_seq_groups]) + # Update swapped requests. + self.swapped = remaining_swapped + self.swapped.extend(running_scheduled.swapped_out) + + # There should be no prefill from running queue because this policy + # doesn't allow chunked prefills. + assert len(running_scheduled.prefill_seq_groups) == 0 + assert len(swapped_in.prefill_seq_groups) == 0 + return SchedulerOutputs( + scheduled_seq_groups=(prefills.seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups), + num_prefill_groups=len(prefills.seq_groups), + num_batched_tokens=budget.num_batched_tokens, + blocks_to_swap_in=swapped_in.blocks_to_swap_in, + blocks_to_swap_out=running_scheduled.blocks_to_swap_out, + blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, + swapped_in.blocks_to_copy), + ignored_seq_groups=prefills.ignored_seq_groups, + num_lookahead_slots=(prefills.num_lookahead_slots + + running_scheduled.num_lookahead_slots + + swapped_in.num_lookahead_slots), + ) + + def _schedule_chunked_prefill(self): + """Schedule queued requests. + + Chunked prefill allows to chunk prefill requests, batch them together + with decode requests. This policy 1. schedule as many decoding requests + as possible. 2. schedule chunked prefill requests that are not + finished. 3. schedule swapped request. 4. schedule new prefill + requests. + + The policy can sustain the high GPU utilization because it can put + prefill and decodes requests to the same batch, while it improves + inter token latency because decodes requests don't need to blocked + by prefill requests. + """ + budget = SchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + ) + curr_loras = set() + + remaining_waiting, prefills = (self.waiting, + SchedulerPrefillOutputs.create_empty()) + remaining_running, running_scheduled = ( + self.running, SchedulerRunningOutputs.create_empty()) + remaining_swapped, swapped_in = ( + self.swapped, SchedulerSwappedInOutputs.create_empty()) + + # Decoding should be always scheduled first by fcfs. + fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs") + remaining_running, running_scheduled = self._schedule_running( + self.running, + budget, + curr_loras, + fcfs_policy, + enable_chunking=True) + + # Schedule swapped out requests. + # If preemption happens, it means we don't have space for swap-in. + if len(running_scheduled.preempted) + len( + running_scheduled.swapped_out) == 0: + remaining_swapped, swapped_in = self._schedule_swapped( + self.swapped, budget, curr_loras, fcfs_policy) + + # Schedule new prefills. + remaining_waiting, prefills = self._schedule_prefills( + self.waiting, budget, curr_loras, enable_chunking=True) + + assert (budget.num_batched_tokens <= + self.scheduler_config.max_num_batched_tokens) + assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs + + # Update waiting requests. + self.waiting = remaining_waiting + self.waiting.extendleft(running_scheduled.preempted) + # Update new running requests. + self.running = remaining_running + self.running.extend([s.seq_group for s in prefills.seq_groups]) + self.running.extend( + [s.seq_group for s in running_scheduled.decode_seq_groups]) + self.running.extend( + [s.seq_group for s in running_scheduled.prefill_seq_groups]) + self.running.extend( + [s.seq_group for s in swapped_in.decode_seq_groups]) + self.running.extend( + [s.seq_group for s in swapped_in.prefill_seq_groups]) + # Update swapped requests. + self.swapped = remaining_swapped + self.swapped.extend(running_scheduled.swapped_out) + + return SchedulerOutputs( + scheduled_seq_groups=(prefills.seq_groups + + running_scheduled.decode_seq_groups + + running_scheduled.prefill_seq_groups + + swapped_in.decode_seq_groups + + swapped_in.prefill_seq_groups), + num_prefill_groups=(len(prefills.seq_groups) + + len(swapped_in.prefill_seq_groups) + + len(running_scheduled.prefill_seq_groups)), + num_batched_tokens=budget.num_batched_tokens, + blocks_to_swap_in=swapped_in.blocks_to_swap_in, + blocks_to_swap_out=running_scheduled.blocks_to_swap_out, + blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, + swapped_in.blocks_to_copy), + ignored_seq_groups=prefills.ignored_seq_groups, + num_lookahead_slots=(prefills.num_lookahead_slots + + running_scheduled.num_lookahead_slots + + swapped_in.num_lookahead_slots), + ) + + def _schedule(self) -> SchedulerOutputs: + """Schedule queued requests.""" + if self.scheduler_config.chunked_prefill_enabled: + return self._schedule_chunked_prefill() + else: + return self._schedule_default() + + def _can_append_slots(self, seq_group: SequenceGroup) -> bool: + """Determine whether or not we have enough space in the KV cache to + continue generation of the sequence group. + """ + # Appending slots only occurs in decoding. + is_prefill = False + + return self.block_manager.can_append_slots( + seq_group=seq_group, + num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), + ) + + def _can_swap_in(self, seq_group: SequenceGroup) -> bool: + # Swapping in is considered decode. + is_prefill = False + + return self.block_manager.can_swap_in( + seq_group=seq_group, + num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), ) - return scheduler_outputs def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # Schedule sequence groups. @@ -417,7 +884,8 @@ class Scheduler: # Create input data structures. seq_group_metadata_list: List[SequenceGroupMetadata] = [] - for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + for i, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): seq_group = scheduled_seq_group.seq_group token_chunk_size = scheduled_seq_group.token_chunk_size seq_group.maybe_set_first_scheduled_time(now) @@ -437,9 +905,12 @@ class Scheduler: self.block_manager.get_common_computed_block_ids( seq_group.get_seqs(status=SequenceStatus.RUNNING))) + # It assumes the scheduled_seq_groups is ordered by + # prefill < decoding. + is_prompt = i < scheduler_outputs.num_prefill_groups seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, - is_prompt=scheduler_outputs.prompt_run, + is_prompt=is_prompt, seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, @@ -452,7 +923,7 @@ class Scheduler: # the subsequent comms can still use delta, but # `multi_modal_data` will be None. multi_modal_data=seq_group.multi_modal_data - if scheduler_outputs.prompt_run else None, + if scheduler_outputs.num_prefill_groups > 0 else None, ) seq_group_metadata_list.append(seq_group_metadata) @@ -477,31 +948,43 @@ class Scheduler: self.running = deque(seq_group for seq_group in self.running if not seq_group.is_finished()) - def _allocate(self, seq_group: SequenceGroup) -> None: + def _allocate_and_set_running(self, seq_group: SequenceGroup, + num_new_tokens: int) -> None: self.block_manager.allocate(seq_group) for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING - def _append_slot( + def _append_slots( self, seq_group: SequenceGroup, blocks_to_copy: Dict[int, List[int]], ) -> None: + """Appends new slots to the sequences in the given sequence group. + + Args: + seq_group (SequenceGroup): The sequence group containing the + sequences to append slots to. + blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source + block indices to lists of destination block indices. This + dictionary is updated with the new source and destination block + indices for the appended slots. + """ + num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - ret = self.block_manager.append_slot(seq) - if ret is not None: - src_block, dst_block = ret - if src_block in blocks_to_copy: - blocks_to_copy[src_block].append(dst_block) - else: - blocks_to_copy[src_block] = [dst_block] + cows = self.block_manager.append_slots(seq, num_lookahead_slots) + + for src, dests in cows.items(): + if src not in blocks_to_copy: + blocks_to_copy[src] = [] + blocks_to_copy[src].extend(dests) def _preempt( self, seq_group: SequenceGroup, blocks_to_swap_out: Dict[int, int], preemption_mode: Optional[PreemptionMode] = None, - ) -> None: + ) -> PreemptionMode: # If preemption mode is not specified, we determine the mode as follows: # We use recomputation by default since it incurs lower overhead than # swapping. However, when the sequence group has multiple sequences @@ -524,6 +1007,7 @@ class Scheduler: self._preempt_by_swap(seq_group, blocks_to_swap_out) else: raise AssertionError("Invalid preemption mode.") + return preemption_mode def _preempt_by_recompute( self, @@ -535,9 +1019,6 @@ class Scheduler: seq.status = SequenceStatus.WAITING self.free_seq(seq) seq.reset_state_for_recompute() - # NOTE: For FCFS, we insert the preempted sequence group to the front - # of the waiting queue. - self.waiting.appendleft(seq_group) def _preempt_by_swap( self, @@ -545,7 +1026,6 @@ class Scheduler: blocks_to_swap_out: Dict[int, int], ) -> None: self._swap_out(seq_group, blocks_to_swap_out) - self.swapped.append(seq_group) def _swap_in( self, @@ -588,3 +1068,39 @@ class Scheduler: else: passed_delay = True return passed_delay + + def _get_num_lookahead_slots(self, is_prefill: bool) -> int: + """The number of slots to allocate per sequence per step, beyond known + token ids. Speculative decoding uses these slots to store KV activations + of tokens which may or may not be accepted. + + Speculative decoding does not yet support prefill, so we do not perform + lookahead allocation for prefill. + """ + if is_prefill: + return 0 + + return self.scheduler_config.num_lookahead_slots + + def _get_num_new_tokens(self, seq_group: SequenceGroup, + status: SequenceStatus, enable_chunking: bool, + budget: SchedulingBudget) -> Tuple[int, bool]: + """Get the next new tokens to compute for a given sequence group + that's in a given `status`. + + The API could chunk the number of tokens to compute based on `budget` + if `enable_chunking` is True. If a sequence group has multiple + sequences (e.g., running beam search), it means it is in decoding + phase, so chunking doesn't happen. + """ + num_new_tokens = 0 + seqs = seq_group.get_seqs(status=status) + for seq in seqs: + num_new_tokens += seq.get_num_new_tokens() + # Chunk if a running request cannot fit in. + # If number of seq > 1, it means it is doing beam search in a + # decode phase. Do not chunk in that case. + if enable_chunking and len(seqs) == 1: + num_new_tokens = min(num_new_tokens, + budget.remaining_token_budget()) + return num_new_tokens diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 83ef7ca182c3d..d4b573992c06c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,10 +1,11 @@ import argparse import dataclasses from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, TokenizerPoolConfig, +from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, TokenizerPoolConfig, VisionLanguageConfig) from vllm.utils import str_to_int_tuple @@ -20,6 +21,7 @@ class EngineArgs: load_format: str = 'auto' dtype: str = 'auto' kv_cache_dtype: str = 'auto' + quantization_param_path: Optional[str] = None seed: int = 0 max_model_len: Optional[int] = None worker_use_ray: bool = False @@ -53,17 +55,22 @@ class EngineArgs: max_cpu_loras: Optional[int] = None device: str = 'auto' ray_workers_use_nsight: bool = False - - forced_num_gpu_blocks: Optional[int] = None + num_gpu_blocks_override: Optional[int] = None + num_lookahead_slots: int = 0 # Related to Vision-language models such as llava image_input_type: Optional[str] = None image_token_id: Optional[int] = None image_input_shape: Optional[str] = None image_feature_size: Optional[int] = None + scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False + # Speculative decoding configuration. + speculative_model: Optional[str] = None + num_speculative_tokens: Optional[int] = None + def __post_init__(self): if self.tokenizer is None: self.tokenizer = self.model @@ -153,11 +160,23 @@ class EngineArgs: parser.add_argument( '--kv-cache-dtype', type=str, - choices=['auto', 'fp8_e5m2'], + choices=['auto', 'fp8'], default=EngineArgs.kv_cache_dtype, help='Data type for kv cache storage. If "auto", will use model ' - 'data type. Note FP8 is not supported when cuda version is ' - 'lower than 11.8.') + 'data type. FP8_E5M2 (without scaling) is only supported on cuda ' + 'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' + 'supported for common inference criteria. ') + parser.add_argument( + '--quantization-param-path', + type=str, + default=None, + help='Path to the JSON file containing the KV cache ' + 'scaling factors. This should generally be supplied, when ' + 'KV cache dtype is FP8. Otherwise, KV cache scaling factors ' + 'default to 1.0, which may cause accuracy issues. ' + 'FP8_E5M2 (without scaling) is only supported on cuda version' + 'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' + 'supported for common inference criteria. ') parser.add_argument('--max-model-len', type=int, default=EngineArgs.max_model_len, @@ -202,6 +221,14 @@ class EngineArgs: parser.add_argument('--use-v2-block-manager', action='store_true', help='Use BlockSpaceMangerV2') + parser.add_argument( + '--num-lookahead-slots', + type=int, + default=EngineArgs.num_lookahead_slots, + help='Experimental scheduling config necessary for ' + 'speculative decoding. This will be replaced by ' + 'speculative config in the future; it is present ' + 'to enable correctness tests until then.') parser.add_argument('--seed', type=int, @@ -219,7 +246,7 @@ class EngineArgs: 'the model executor, which can range from 0 to 1.' 'If unspecified, will use the default value of 0.9.') parser.add_argument( - '--forced-num-gpu-blocks', + '--num-gpu-blocks-override', type=int, default=None, help='If specified, ignore GPU profiling result and use this number' @@ -324,7 +351,7 @@ class EngineArgs: parser.add_argument("--device", type=str, default=EngineArgs.device, - choices=["auto", "cuda", "neuron"], + choices=["auto", "cuda", "neuron", "cpu"], help='Device type for vLLM execution.') # Related to Vision-language models such as llava parser.add_argument( @@ -363,6 +390,20 @@ class EngineArgs: default=False, help='If True, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') + + parser.add_argument( + '--speculative-model', + type=str, + default=None, + help= + 'The name of the draft model to be used in speculative decoding.') + + parser.add_argument( + '--num-speculative-tokens', + type=int, + default=None, + help='The number of speculative tokens to sample from ' + 'the draft model in speculative decoding') return parser @classmethod @@ -373,23 +414,19 @@ class EngineArgs: engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) return engine_args - def create_engine_configs( - self, - ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, - DeviceConfig, Optional[LoRAConfig], - Optional[VisionLanguageConfig]]: + def create_engine_config(self, ) -> EngineConfig: device_config = DeviceConfig(self.device) model_config = ModelConfig( self.model, self.tokenizer, self.tokenizer_mode, self.trust_remote_code, self.download_dir, self.load_format, self.dtype, self.seed, self.revision, self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, - self.enforce_eager, self.max_context_len_to_capture, - self.max_logprobs) + self.quantization_param_path, self.enforce_eager, + self.max_context_len_to_capture, self.max_logprobs) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, - self.forced_num_gpu_blocks, + self.num_gpu_blocks_override, model_config.get_sliding_window(), self.enable_prefix_caching) parallel_config = ParallelConfig( @@ -401,11 +438,23 @@ class EngineArgs: self.tokenizer_pool_type, self.tokenizer_pool_extra_config, ), self.ray_workers_use_nsight) + + speculative_config = SpeculativeConfig.maybe_create_spec_config( + target_model_config=model_config, + target_parallel_config=parallel_config, + target_dtype=self.dtype, + speculative_model=self.speculative_model, + num_speculative_tokens=self.num_speculative_tokens, + ) + scheduler_config = SchedulerConfig( self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, self.use_v2_block_manager, + num_lookahead_slots=(self.num_lookahead_slots + if speculative_config is None else + speculative_config.num_lookahead_slots), delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, ) @@ -433,8 +482,14 @@ class EngineArgs: else: vision_language_config = None - return (model_config, cache_config, parallel_config, scheduler_config, - device_config, lora_config, vision_language_config) + return EngineConfig(model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + vision_language_config=vision_language_config, + speculative_config=speculative_config) @dataclass diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 2e6f5d69a0420..f610495135121 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -328,28 +328,27 @@ class AsyncLLMEngine: ) -> "AsyncLLMEngine": """Creates an async LLM engine from the engine arguments.""" # Create the engine configs. - engine_configs = engine_args.create_engine_configs() - parallel_config = engine_configs[2] - device_config = engine_configs[4] + engine_config = engine_args.create_engine_config() - if device_config.device_type == "neuron": + if engine_config.device_config.device_type == "neuron": raise NotImplementedError("Neuron is not supported for " "async engine yet.") - elif parallel_config.worker_use_ray or engine_args.engine_use_ray: - initialize_ray_cluster(parallel_config) + elif (engine_config.parallel_config.worker_use_ray + or engine_args.engine_use_ray): + initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync executor_class = RayGPUExecutorAsync else: - assert parallel_config.world_size == 1, ( + assert engine_config.parallel_config.world_size == 1, ( "Ray is required if parallel_config.world_size > 1.") from vllm.executor.gpu_executor import GPUExecutorAsync executor_class = GPUExecutorAsync # Create the async LLM engine. engine = cls( - parallel_config.worker_use_ray, + engine_config.parallel_config.worker_use_ray, engine_args.engine_use_ray, - *engine_configs, - executor_class, + **engine_config.to_dict(), + executor_class=executor_class, log_requests=not engine_args.disable_log_requests, log_stats=not engine_args.disable_log_stats, max_log_len=engine_args.max_log_len, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 66d75a9ffdf89..8173c49394f9c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -5,7 +5,8 @@ from transformers import PreTrainedTokenizer import vllm from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) + ParallelConfig, SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats @@ -13,7 +14,6 @@ from vllm.engine.ray_utils import initialize_ray_cluster from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.model_loader import get_architecture_class_name from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, @@ -53,6 +53,11 @@ class LLMEngine: parallel_config: The configuration related to distributed execution. scheduler_config: The configuration related to the request scheduler. device_config: The configuration related to the device. + lora_config (Optional): The configuration related to serving multi-LoRA. + vision_language_config (Optional): The configuration related to vision + language models. + speculative_config (Optional): The configuration related to speculative + decoding. executor_class: The model executor class for managing distributed execution. log_stats: Whether to log statistics. @@ -67,7 +72,8 @@ class LLMEngine: scheduler_config: SchedulerConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], - vision_language_config: Optional["VisionLanguageConfig"], + vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -75,6 +81,7 @@ class LLMEngine: logger.info( f"Initializing an LLM engine (v{vllm.__version__}) with config: " f"model={model_config.model!r}, " + f"speculative_config={speculative_config!r}, " f"tokenizer={model_config.tokenizer!r}, " f"tokenizer_mode={model_config.tokenizer_mode}, " f"revision={model_config.revision}, " @@ -90,6 +97,7 @@ class LLMEngine: f"quantization={model_config.quantization}, " f"enforce_eager={model_config.enforce_eager}, " f"kv_cache_dtype={cache_config.cache_dtype}, " + f"quantization_param_path={model_config.quantization_param_path}, " f"device_config={device_config.device}, " f"seed={model_config.seed})") # TODO(woosuk): Print more configs in debug mode. @@ -101,20 +109,30 @@ class LLMEngine: self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config + self.speculative_config = speculative_config self.log_stats = log_stats - self._verify_args() self._init_tokenizer() self.detokenizer = Detokenizer(self.tokenizer) self.seq_counter = Counter() - self.model_executor = executor_class(model_config, cache_config, - parallel_config, scheduler_config, - device_config, lora_config, - vision_language_config) + self.model_executor = executor_class( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + vision_language_config=vision_language_config, + speculative_config=speculative_config, + ) + + self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. if is_usage_stats_enabled(): + from vllm.model_executor.model_loader import ( + get_architecture_class_name) usage_message.report_usage( get_architecture_class_name(model_config), usage_context, @@ -162,6 +180,26 @@ class LLMEngine: labels=dict(model_name=model_config.model)) self.stat_logger.info("cache_config", self.cache_config) + def _initialize_kv_caches(self) -> None: + """Initialize the KV cache in the worker(s). + + The workers will determine the number of blocks in both the GPU cache + and the swap CPU cache. + """ + num_gpu_blocks, num_cpu_blocks = ( + self.model_executor.determine_num_available_blocks()) + + if self.cache_config.num_gpu_blocks_override is not None: + num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override + logger.info(f"Overriding {num_gpu_blocks=} with " + f"{num_gpu_blocks_override=}") + num_gpu_blocks = num_gpu_blocks_override + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) + @classmethod def from_engine_args( cls, @@ -170,30 +208,31 @@ class LLMEngine: ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" # Create the engine configs. - engine_configs = engine_args.create_engine_configs() - parallel_config = engine_configs[2] - device_config = engine_configs[4] + engine_config = engine_args.create_engine_config() # Initialize the cluster and specify the executor class. - if device_config.device_type == "neuron": + if engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutor executor_class = NeuronExecutor - elif device_config.device_type == "tpu": + elif engine_config.device_config.device_type == "tpu": from vllm.executor.tpu_executor import TPUExecutor executor_class = TPUExecutor - elif parallel_config.worker_use_ray: - initialize_ray_cluster(parallel_config) + elif engine_config.device_config.device_type == "cpu": + from vllm.executor.cpu_executor import CPUExecutor + executor_class = CPUExecutor + elif engine_config.parallel_config.worker_use_ray: + initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutor executor_class = RayGPUExecutor else: - assert parallel_config.world_size == 1, ( + assert engine_config.parallel_config.world_size == 1, ( "Ray is required if parallel_config.world_size > 1.") from vllm.executor.gpu_executor import GPUExecutor executor_class = GPUExecutor # Create the LLM engine. engine = cls( - *engine_configs, + **engine_config.to_dict(), executor_class=executor_class, log_stats=not engine_args.disable_log_stats, usage_context=usage_context, @@ -418,7 +457,7 @@ class LLMEngine: # Process prompt logprobs prompt_logprobs = outputs.prompt_logprobs - if prompt_logprobs is not None: + if prompt_logprobs is not None and seq_group.sampling_params.detokenize: self.detokenizer.decode_prompt_logprobs_inplace( seq_group, prompt_logprobs) seq_group.prompt_logprobs = prompt_logprobs @@ -464,8 +503,9 @@ class LLMEngine: child_seqs.append((parent, parent)) for seq, _ in child_seqs: - self.detokenizer.decode_sequence_inplace(seq, - seq_group.sampling_params) + if seq_group.sampling_params.detokenize: + self.detokenizer.decode_sequence_inplace( + seq, seq_group.sampling_params) self._check_stop(seq, seq_group.sampling_params) # Non-beam search case @@ -592,11 +632,10 @@ class LLMEngine: now = time.time() # Update the scheduled sequence groups with the model outputs. scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups - for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output): seq_group = scheduled_seq_group.seq_group - token_chunk_size = scheduled_seq_group.token_chunk_size - seq_group.update_num_computed_tokens(token_chunk_size) + seq_group.update_num_computed_tokens( + scheduled_seq_group.token_chunk_size) self._process_sequence_group_outputs(seq_group, outputs) # Free the finished sequence groups. @@ -715,7 +754,7 @@ class LLMEngine: time_per_output_tokens = [] time_e2e_requests = [] if scheduler_outputs is not None: - prompt_run = scheduler_outputs.prompt_run + prompt_run = scheduler_outputs.num_prefill_groups > 0 # Number of Tokens. if prompt_run: @@ -777,12 +816,13 @@ class LLMEngine: if seq.get_output_len() < sampling_params.min_tokens: return - for stop_str in sampling_params.stop: - if seq.output_text.endswith(stop_str): - self._finalize_sequence(seq, sampling_params, stop_str) - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = stop_str - return + if sampling_params.detokenize: + for stop_str in sampling_params.stop: + if seq.output_text.endswith(stop_str): + self._finalize_sequence(seq, sampling_params, stop_str) + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = stop_str + return last_token_id = seq.get_last_token_id() if last_token_id in sampling_params.stop_token_ids: stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens( diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e550943c88725..32282bfd8d12b 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -127,7 +127,8 @@ if __name__ == "__main__": @app.middleware("http") async def authentication(request: Request, call_next): - if not request.url.path.startswith("/v1"): + root_path = "" if args.root_path is None else args.root_path + if not request.url.path.startswith(f"{root_path}/v1"): return await call_next(request) if request.headers.get("Authorization") != "Bearer " + token: return JSONResponse(content={"error": "Unauthorized"}, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0b691feb8483f..f94d22d279cc4 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -4,7 +4,7 @@ import time from typing import Dict, List, Literal, Optional, Union import torch -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, conint, model_validator from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid @@ -229,6 +229,7 @@ class CompletionRequest(BaseModel): min_tokens: Optional[int] = 0 skip_special_tokens: Optional[bool] = True spaces_between_special_tokens: Optional[bool] = True + truncate_prompt_tokens: Optional[conint(ge=1)] = None # doc: end-completion-sampling-params # doc: begin-completion-extra-params @@ -309,6 +310,7 @@ class CompletionRequest(BaseModel): include_stop_str_in_output=self.include_stop_str_in_output, length_penalty=self.length_penalty, logits_processors=logits_processors, + truncate_prompt_tokens=self.truncate_prompt_tokens, ) @model_validator(mode="before") diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 3d1b16f528170..06e7a9225fefb 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -137,10 +137,16 @@ class OpenAIServingCompletion(OpenAIServing): for i, prompt in enumerate(prompts): if prompt_is_tokens: input_ids = self._validate_prompt_and_tokenize( - request, prompt_ids=prompt) + request, + prompt_ids=prompt, + truncate_prompt_tokens=sampling_params. + truncate_prompt_tokens) else: input_ids = self._validate_prompt_and_tokenize( - request, prompt=prompt) + request, + prompt=prompt, + truncate_prompt_tokens=sampling_params. + truncate_prompt_tokens) generators.append( self.engine.generate(prompt, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 9dbd1750e631a..8f69388c0251e 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -4,6 +4,8 @@ from dataclasses import dataclass from http import HTTPStatus from typing import Dict, List, Optional, Union +from pydantic import conint + from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest, ErrorResponse, @@ -66,7 +68,8 @@ class OpenAIServing: self.tokenizer = get_tokenizer( engine_model_config.tokenizer, tokenizer_mode=engine_model_config.tokenizer_mode, - trust_remote_code=engine_model_config.trust_remote_code) + trust_remote_code=engine_model_config.trust_remote_code, + truncation_side="left") async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" @@ -164,15 +167,26 @@ class OpenAIServing: self, request: Union[ChatCompletionRequest, CompletionRequest], prompt: Optional[str] = None, - prompt_ids: Optional[List[int]] = None) -> List[int]: + prompt_ids: Optional[List[int]] = None, + truncate_prompt_tokens: Optional[conint(ge=1)] = None + ) -> List[int]: if not (prompt or prompt_ids): raise ValueError("Either prompt or prompt_ids should be provided.") if (prompt and prompt_ids): raise ValueError( "Only one of prompt or prompt_ids should be provided.") - input_ids = prompt_ids if prompt_ids is not None else self.tokenizer( - prompt).input_ids + if prompt_ids is None: + tokenizer_kwargs = {} if truncate_prompt_tokens is None else { + "truncation": True, + "max_length": truncate_prompt_tokens, + } + input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids + elif truncate_prompt_tokens is not None: + input_ids = prompt_ids[-truncate_prompt_tokens:] + else: + input_ids = prompt_ids + token_num = len(input_ids) if request.max_tokens is None: diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py new file mode 100644 index 0000000000000..2bf97338da0ed --- /dev/null +++ b/vllm/executor/cpu_executor.py @@ -0,0 +1,140 @@ +import os +from typing import Dict, List, Optional + +import torch + +from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.executor.executor_base import ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import get_distributed_init_method, get_ip, get_open_port + +logger = init_logger(__name__) + + +class CPUExecutor(ExecutorBase): + + def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], *args, **kwargs) -> None: + assert device_config.device_type == "cpu" + assert lora_config is None, "cpu backend doesn't support LoRA" + model_config = _verify_and_get_model_config(model_config) + cache_config = _verify_and_get_cache_config(cache_config) + + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + + # Instantiate the worker and load the model to CPU. + self._init_worker() + + def _init_worker(self): + from vllm.worker.cpu_worker import CPUWorker + + assert self.parallel_config.world_size == 1, ( + "CPUExecutor only supports single CPU socket currently.") + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + self.driver_worker = CPUWorker( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=True, + ) + self.driver_worker.init_device() + self.driver_worker.load_model() + + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.driver_worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. + logger.info(f"# CPU blocks: {num_cpu_blocks}") + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + def execute_model(self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + output = self.driver_worker.execute_model( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ) + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.driver_worker.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.driver_worker.remove_lora(lora_id) + + def list_loras(self) -> List[int]: + return self.driver_worker.list_loras() + + def check_health(self) -> None: + # CPUExecutor will always be healthy as long as + # it's running. + return + + +def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: + if config.dtype == torch.float16: + logger.warning("float16 is not supported on CPU, casting to bfloat16.") + config.dtype = torch.bfloat16 + if not config.enforce_eager: + logger.warning( + "CUDA graph is not supported on CPU, fallback to the eager " + "mode.") + config.enforce_eager = True + return config + + +def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: + _GB = 1 << 30 + if config.enable_prefix_caching: + logger.warning("Prefix caching is not supported on CPU, disable it.") + config.enable_prefix_caching = False + + kv_cache_space_str = os.getenv("VLLM_CPU_KVCACHE_SPACE", "0") + kv_cache_space = int(kv_cache_space_str) + + if kv_cache_space >= 0: + if kv_cache_space == 0: + config.cpu_kvcache_space_bytes = 4 * _GB # type: ignore + logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) " + "for CPU backend is not set, using 4 by default.") + else: + config.cpu_kvcache_space_bytes = kv_cache_space * _GB # type: ignore + else: + raise RuntimeError( + "Invalid environment variable VLLM_CPU_KVCACHE_SPACE" + f" {kv_cache_space}, expect a positive integer value.") + + return config diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 55180d6110b6b..c18edd75d7a4d 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -2,7 +2,8 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) + ParallelConfig, SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -25,9 +26,33 @@ class ExecutorBase(ABC): device_config: DeviceConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], ) -> None: raise NotImplementedError + @abstractmethod + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available blocks for the GPU KV cache and + swappable CPU KV cache. + + Normally, this should simply delegate to the underlying Worker. Some + ExecutorBase may require modification of the result, e.g. to ensure the + selected cache sizes are compatible with all workers. + + Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + are blocks that are "active" on the device and can be appended to. + num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + appended to. + """ + raise NotImplementedError + + @abstractmethod + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache with the given size in blocks. + """ + raise NotImplementedError + @abstractmethod def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index adbc4cb703f67..80ca5cb7367c5 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,9 +1,9 @@ from typing import Dict, List, Optional from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) + ParallelConfig, SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase -from vllm.executor.utils import check_block_size_valid from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -24,6 +24,7 @@ class GPUExecutor(ExecutorBase): device_config: DeviceConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config @@ -33,12 +34,12 @@ class GPUExecutor(ExecutorBase): self.device_config = device_config self.vision_language_config = vision_language_config + assert (not speculative_config + ), "Speculative decoding not yet supported for GPU backend" + # Instantiate the worker and load the model to GPU. self._init_worker() - # Profile the memory usage and initialize the cache. - self._init_cache() - def _init_worker(self): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker @@ -50,61 +51,37 @@ class GPUExecutor(ExecutorBase): distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) self.driver_worker = Worker( - self.model_config, - self.parallel_config, - self.scheduler_config, - self.device_config, + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, - kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) self.driver_worker.init_device() self.driver_worker.load_model() - def _init_cache(self) -> None: - """Profiles the memory usage and initializes the KV cache. - - The engine first profiles the existing memory usage. - Then, it allocates the remaining memory for KV blocks. - - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_gpu_blocks, num_cpu_blocks = ( - self.driver_worker.profile_num_available_blocks( - block_size=self.cache_config.block_size, - gpu_memory_utilization=self.cache_config. - gpu_memory_utilization, - cpu_swap_space=self.cache_config.swap_space_bytes, - cache_dtype=self.cache_config.cache_dtype, - )) - - if self.cache_config.forced_num_gpu_blocks is not None: - forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks - logger.info(f"Replacing profiled {num_gpu_blocks=} with " - f"{forced_num_gpu_blocks=}") - num_gpu_blocks = forced_num_gpu_blocks + return self.driver_worker.determine_num_available_blocks() + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: This is logged in the executor because there can be >1 worker + # with other executors. We could log in the engine level, but work + # remains to abstract away the device for non-GPU configurations. logger.info(f"# GPU blocks: {num_gpu_blocks}, " f"# CPU blocks: {num_cpu_blocks}") - check_block_size_valid(num_gpu_blocks, self.cache_config.block_size, - self.model_config.max_model_len) - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - # Initialize the cache. - self.driver_worker.init_cache_engine(cache_config=self.cache_config) - # Warm up the model. This includes capturing the model into CUDA graph - # if enforce_eager is False. - self.driver_worker.warm_up_model() + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index f64c411cc6cb0..57436a85cfa27 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -1,7 +1,8 @@ from typing import Dict, List, Optional from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) + ParallelConfig, SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -21,19 +22,15 @@ class NeuronExecutor(ExecutorBase): device_config: DeviceConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], ) -> None: self.model_config = model_config - self.cache_config = cache_config assert lora_config is None, "LoRA is not supported for Neuron backend." self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config - - # Set the number of GPU blocks to be the same as the maximum number of - # sequences that can be processed in a single batch. This is equivalent - # to schedule without PagedAttention. - self.cache_config.num_gpu_blocks = self.scheduler_config.max_num_seqs - self.cache_config.num_cpu_blocks = 0 + assert (not speculative_config + ), "Speculative decoding not yet supported for Neuron backend." # Instantiate the worker and load the model to the device. self._init_worker() @@ -50,6 +47,18 @@ class NeuronExecutor(ExecutorBase): self.driver_worker.init_device() self.driver_worker.load_model() + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.driver_worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], @@ -64,16 +73,13 @@ class NeuronExecutor(ExecutorBase): return output def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError( - "LoRA is not implemented for neuron backend.") + return self.driver_worker.add_lora(lora_request) def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError( - "LoRA is not implemented for neuron backend.") + return self.driver_worker.remove_lora(lora_id) def list_loras(self) -> List[int]: - raise NotImplementedError( - "LoRA is not implemented for neuron backend.") + return self.driver_worker.list_loras() def check_health(self) -> None: # NeuronExecutor will always be healthy as long as diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 8f80c20738bba..6c0ccd7e64c90 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -6,10 +6,10 @@ from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) + ParallelConfig, SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.engine.ray_utils import RayWorkerVllm, ray from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase -from vllm.executor.utils import check_block_size_valid from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -41,6 +41,7 @@ class RayGPUExecutor(ExecutorBase): device_config: DeviceConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config @@ -49,6 +50,8 @@ class RayGPUExecutor(ExecutorBase): self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config + assert (not speculative_config + ), "Speculative decoding not yet supported for RayGPU backend." assert self.parallel_config.worker_use_ray placement_group = self.parallel_config.placement_group @@ -61,9 +64,6 @@ class RayGPUExecutor(ExecutorBase): # Create the parallel GPU workers. self._init_workers_ray(placement_group) - # Profile the memory usage and initialize the cache. - self._init_cache() - self.forward_dag = None if USE_RAY_COMPILED_DAG: self.forward_dag = self._compiled_ray_dag() @@ -150,7 +150,8 @@ class RayGPUExecutor(ExecutorBase): scheduler_config = copy.deepcopy(self.scheduler_config) device_config = copy.deepcopy(self.device_config) lora_config = copy.deepcopy(self.lora_config) - kv_cache_dtype = self.cache_config.cache_dtype + cache_config = copy.deepcopy(self.cache_config) + vision_language_config = copy.deepcopy(self.vision_language_config) # Initialize the actual workers with the Worker class. for rank, (worker, (node_id, _)) in enumerate( @@ -160,31 +161,32 @@ class RayGPUExecutor(ExecutorBase): local_rank = node_workers[node_id].index(rank) worker.init_worker.remote( lambda rank=rank, local_rank=local_rank: Worker( - model_config, - parallel_config, - scheduler_config, - device_config, - local_rank, - rank, - distributed_init_method, + model_config=model_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + cache_config=cache_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, lora_config=lora_config, - kv_cache_dtype=kv_cache_dtype, + vision_language_config=vision_language_config, )) # Initialize the driver worker with the Worker class. driver_rank = 0 driver_local_rank = node_workers[driver_node_id].index(driver_rank) self.driver_worker = Worker( - self.model_config, - self.parallel_config, - self.scheduler_config, - self.device_config, - driver_local_rank, - driver_rank, - distributed_init_method, + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=driver_local_rank, + rank=driver_rank, + distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, - kv_cache_dtype=kv_cache_dtype, is_driver_worker=True, ) @@ -195,35 +197,18 @@ class RayGPUExecutor(ExecutorBase): max_parallel_loading_workers, ) - def _init_cache(self) -> None: - """Profiles the memory usage and initializes the KV cache. + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available KV blocks. - The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. - More details can be found in the - :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method - from class :class:`~vllm.worker.Worker`. + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. - Afterwards, as there may be multiple workers, - we take the minimum number of blocks across all workers - to ensure this can be applied to all of them. - - Finally, the engine will initialize the KV cache - with the calculated number of blocks. - - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. + Returns: + - tuple[num_gpu_blocks, num_cpu_blocks] """ # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_blocks = self._run_workers( - "profile_num_available_blocks", - block_size=self.cache_config.block_size, - gpu_memory_utilization=self.cache_config.gpu_memory_utilization, - cpu_swap_space=self.cache_config.swap_space_bytes, - cache_dtype=self.cache_config.cache_dtype, - ) + num_blocks = self._run_workers("determine_num_available_blocks", ) # Since we use a shared centralized controller, we take the minimum # number of blocks across all workers to make sure all the memory @@ -231,26 +216,25 @@ class RayGPUExecutor(ExecutorBase): num_gpu_blocks = min(b[0] for b in num_blocks) num_cpu_blocks = min(b[1] for b in num_blocks) - if self.cache_config.forced_num_gpu_blocks is not None: - forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks - logger.info(f"Replacing profiled {num_gpu_blocks=} with " - f"{forced_num_gpu_blocks=}") - num_gpu_blocks = forced_num_gpu_blocks + return num_gpu_blocks, num_cpu_blocks + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache in all workers. + """ + + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. logger.info(f"# GPU blocks: {num_gpu_blocks}, " f"# CPU blocks: {num_cpu_blocks}") - check_block_size_valid(num_gpu_blocks, self.cache_config.block_size, - self.model_config.max_model_len) - self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - # Initialize the cache. - self._run_workers("init_cache_engine", cache_config=self.cache_config) - # Warm up the model. This includes capturing the model into CUDA graph - # if enforce_eager is False. - self._run_workers("warm_up_model") + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], diff --git a/vllm/executor/utils.py b/vllm/executor/utils.py deleted file mode 100644 index 44976696a77c6..0000000000000 --- a/vllm/executor/utils.py +++ /dev/null @@ -1,13 +0,0 @@ -def check_block_size_valid(num_gpu_blocks, block_size, max_model_len) -> None: - if num_gpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - max_seq_len = block_size * num_gpu_blocks - if max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). Try increasing " - "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 945917a5aa86b..62f1502458008 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -191,6 +191,7 @@ class LoRAModel: def from_local_checkpoint( cls, lora_dir: str, + expected_lora_modules: List[str], lora_model_id: Optional[int] = None, device: str = "cuda", dtype: Optional[torch.dtype] = None, @@ -206,6 +207,20 @@ class LoRAModel: lora_dir, "new_embeddings.safetensors") new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin") + with open(lora_config_path) as f: + config = json.load(f) + target_modules = config["target_modules"] + unexpected_modules = [] + for module in target_modules: + if module not in expected_lora_modules: + unexpected_modules.append(module) + # loaded lora's target modules must be a subset of expected_lora_modules + if unexpected_modules: + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." + f" Please verify that the loaded LoRA module is correct") if os.path.isfile(lora_tensor_path): tensors = safetensors.torch.load_file(lora_tensor_path) elif os.path.isfile(lora_bin_file_path): @@ -220,8 +235,6 @@ class LoRAModel: elif os.path.isfile(new_embeddings_bin_file_path): embeddings = torch.load(new_embeddings_bin_file_path) - with open(lora_config_path) as f: - config = json.load(f) rank = config["r"] lora_alpha = config["lora_alpha"] return cls.from_lora_tensors( diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 3224b3a9e3eb0..a0868defbd3ca 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -136,8 +136,19 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: try: + model = self._lora_manager.model + supported_lora_modules = model.supported_lora_modules + packed_modules_mapping = model.packed_modules_mapping + expected_lora_modules = [] + for module in supported_lora_modules: + if module in packed_modules_mapping: + expected_lora_modules.extend( + packed_modules_mapping[module]) + else: + expected_lora_modules.append(module) lora = self._lora_model_cls.from_local_checkpoint( lora_request.lora_local_path, + expected_lora_modules, lora_model_id=lora_request.lora_int_id, device="cpu", dtype=self.lora_config.lora_dtype, diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 28e8f6bb7e638..ec531f79ced52 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -86,8 +86,16 @@ def _apply_logits_processors( ) -> torch.Tensor: logits_row_idx = 0 found_logits_processors = False - for seq_ids, sampling_params in sampling_metadata.seq_groups: + for i, seq_group in enumerate(sampling_metadata.seq_groups): + seq_ids, sampling_params = seq_group logits_processors = sampling_params.logits_processors + # handle prompt_logprobs by skipping rows in logits added for + # the prompt tokens (prompt logprobs are not processed) + if (i < sampling_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + assert len(seq_ids) == 1 + logits_row_idx += sampling_metadata.prompt_lens[i] - 1 + if logits_processors: found_logits_processors = True for seq_id in seq_ids: @@ -100,5 +108,6 @@ def _apply_logits_processors( else: logits_row_idx += len(seq_ids) if found_logits_processors: + # verifies that no rows in logits were missed unexpectedly assert logits_row_idx == logits.shape[0] return logits diff --git a/vllm/model_executor/layers/quantization/schema.py b/vllm/model_executor/layers/quantization/schema.py new file mode 100644 index 0000000000000..a26c524787a0b --- /dev/null +++ b/vllm/model_executor/layers/quantization/schema.py @@ -0,0 +1,84 @@ +""" +This file contains the Pydantic schemas for various quantization-related +parameters. When a relevant quantization technique is specified, these +parameters are loaded in the form of a JSON alongside the model weights +and augment the model with additional information needed for use of that +technique. The format of this JSON should be specified by one or more +schemas contained here. + +For example, when the KV cache is quantized to FP8-E4M3 (currently only +possible on ROCm), the model can be optionally augmented with KV cache +scaling factors. +""" + +from typing import Dict, Optional + +from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator + + +class KVCacheQuantSchema(BaseModel): + dtype: str + # Each key is a TP rank. Each value is a dictionary mapping a TP rank's + # layer indices to their per-tensor KV cache scaling factor. + # TODO: Consider pulling this and its validation methods out into its + # own schema class (tricky as its members are variable) + scaling_factor: Dict[int, Dict[int, float]] + + @model_validator(mode="after") + def check_is_fp8(self) -> "KVCacheQuantSchema": + assert self.dtype == "float8_e4m3fn", ( + "Loaded scaling factors intended for KV cache dtype = " + f"{self.dtype} rather than float8_e4m3fn!") + return self + + @model_validator(mode="after") + def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_size = context["tp_size"] + num_hidden_layers = context["num_hidden_layers"] + assert len(self.scaling_factor) == tp_size, ( + f"Loaded dictionary has TP size {len(self.scaling_factor)} " + f"but LLM engine is currently running with TP size {tp_size}.") + for tp_rank, layer_maps in self.scaling_factor.items(): + assert len(layer_maps) == num_hidden_layers, ( + f"KV cache scales map for TP rank {tp_rank} is malformed. " + f"Expected {num_hidden_layers} layers, got " + f"{len(layer_maps)}.") + for i in range(tp_size): + assert i in self.scaling_factor, ( + f"KV cache scales map for TP rank {i} not found.") + return self + + @model_validator(mode="after") + def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_rank = context["tp_rank"] + num_hidden_layers = context["num_hidden_layers"] + layer_scales_map = self.scaling_factor[tp_rank] + for i in range(num_hidden_layers): + assert i in layer_scales_map, ( + f"Could not find KV cache scales for layer {i} in " + f"TP rank {tp_rank}.") + return self + + +class QuantParamSchema(BaseModel): + # TODO: Generalize and extend with more fields + # (e.g. weights/activations params) once functionality is enabled + model_config = ConfigDict(protected_namespaces=()) + model_type: Optional[str] + kv_cache: KVCacheQuantSchema + + @model_validator(mode="after") + def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema": + context = info.context + if context: + model_type = context.get("model_type", None) + if model_type is not None: + assert model_type == self.model_type, ( + f"Model type is {model_type} but loaded " + f"scaling factors belonging to different " + f"model type {self.model_type}!") + return self diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index b5c7e44de619c..17fc970568042 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,5 +1,5 @@ import importlib -from typing import List, Optional, Type +from typing import Dict, List, Optional, Type import torch.nn as nn @@ -41,6 +41,7 @@ _MODELS = { # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), + "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), @@ -55,6 +56,10 @@ _MODELS = { "XverseForCausalLM": ("xverse", "XverseForCausalLM"), } +# Architecture -> type. +# out of tree models +_OOT_MODELS: Dict[str, Type[nn.Module]] = {} + # Models not supported by ROCm. _ROCM_UNSUPPORTED_MODELS = [] @@ -74,6 +79,8 @@ class ModelRegistry: @staticmethod def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch in _OOT_MODELS: + return _OOT_MODELS[model_arch] if model_arch not in _MODELS: return None if is_hip(): @@ -95,6 +102,16 @@ class ModelRegistry: def get_supported_archs() -> List[str]: return list(_MODELS.keys()) + @staticmethod + def register_model(model_arch: str, model_cls: Type[nn.Module]): + if model_arch in _MODELS: + logger.warning( + f"Model architecture {model_arch} is already registered, " + "and will be overwritten by the new model " + f"class {model_cls.__name__}.") + global _OOT_MODELS + _OOT_MODELS[model_arch] = model_cls + __all__ = [ "ModelRegistry", diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index ee6d36f69506f..29ba3844eb11d 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -25,6 +25,7 @@ from typing import List, Optional, Tuple import torch import torch.utils.checkpoint from torch import nn +from torch.nn.parameter import Parameter from transformers import CohereConfig from vllm.attention import Attention, AttentionMetadata @@ -39,8 +40,9 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -48,11 +50,11 @@ from vllm.sequence import SamplerOutput class LayerNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-5, bias=False): + def __init__(self, param_shape=None, eps=1e-5): super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None + self.weight = nn.Parameter(torch.ones(param_shape)) self.variance_epsilon = eps + set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) def forward(self, hidden_states, residuals=None): input_dtype = hidden_states.dtype @@ -62,10 +64,20 @@ class LayerNorm(nn.Module): hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon) hidden_states = self.weight.to(torch.float32) * hidden_states - if self.bias is not None: - hidden_states = hidden_states + self.bias.to(torch.float32) return hidden_states.to(input_dtype), residuals + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + shard_dim = 0 if param.dim() != 1 else None + param_data = param.data + if shard_dim is not None: + shard_size = param_data.shape[shard_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(shard_dim, start_idx, + shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere class CohereMLP(nn.Module): @@ -128,9 +140,12 @@ class CohereAttention(nn.Module): self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.max_position_embeddings = config.max_position_embeddings + self.max_position_embeddings = getattr( + config, "model_max_length", None) or getattr( + config, "max_position_embeddings", 8192) self.rope_theta = config.rope_theta self.rope_scaling = getattr(config, "rope_scaling", None) + self.use_qk_norm = getattr(config, "use_qk_norm", False) self.qkv_proj = QKVParallelLinear( self.hidden_size, self.head_dim, @@ -159,6 +174,22 @@ class CohereAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, ) + if self.use_qk_norm: + self.q_norm = LayerNorm(param_shape=(self.num_heads, + self.head_dim), + eps=config.layer_norm_eps) + self.k_norm = LayerNorm(param_shape=(self.num_kv_heads, + self.head_dim), + eps=config.layer_norm_eps) + + def _apply_qk_norm(self, q, k): + q = q.view(*q.shape[:-1], -1, self.head_dim) + k = k.view(*k.shape[:-1], -1, self.head_dim) + q, _ = self.q_norm(q) + k, _ = self.k_norm(k) + q = q.view(*q.shape[:-2], -1) + k = k.view(*k.shape[:-2], -1) + return q, k def forward( self, @@ -169,6 +200,8 @@ class CohereAttention(nn.Module): ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.use_qk_norm: + q, k = self._apply_qk_norm(q, k) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) @@ -186,7 +219,7 @@ class CohereDecoderLayer(nn.Module): self.self_attn = CohereAttention(config, linear_method=linear_method) self.mlp = CohereMLP(config, linear_method=linear_method) - self.input_layernorm = LayerNorm(config.hidden_size, + self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), eps=config.layer_norm_eps) def forward( @@ -229,7 +262,8 @@ class CohereModel(nn.Module): CohereDecoderLayer(config, linear_method=linear_method) for _ in range(config.num_hidden_layers) ]) - self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.norm = LayerNorm(param_shape=(config.hidden_size), + eps=config.layer_norm_eps) def forward( self, @@ -317,11 +351,21 @@ class CohereForCausalLM(nn.Module): if shard_name not in name: continue name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: + # lm_head is not used in vllm as it is tied with embed_token. + # To prevent errors, skip loading lm_head.weight. + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 673900487cc96..a5b5d717d9846 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -274,6 +274,11 @@ class GPTNeoXForCausalLM(nn.Module): if ("attention.bias" in name or "attention.masked_bias" in name or "rotary_emb.inv_freq" in name): continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using OpenRLHF may include + # these tensors in the checkpoint. Skip them. + continue param = params_dict[name] if "query_key_value" in name: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 57857deb9eb86..72fe21df67d8a 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -41,11 +41,13 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) + hf_model_weights_iterator, + kv_cache_scales_loader) from vllm.sequence import SamplerOutput +from vllm.utils import is_hip class LlamaMLP(nn.Module): @@ -115,6 +117,15 @@ class LlamaAttention(nn.Module): self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings + # This will be overwritten by model initialization if we are using it. + # N.B. currently we only support per tensor scalar scaling factors + # & only applicable to ROCm (AMD GPU). + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + self.kv_scale = 1.0 + self.qkv_proj = QKVParallelLinear( hidden_size, self.head_dim, @@ -153,7 +164,8 @@ class LlamaAttention(nn.Module): qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata, + self.kv_scale) output, _ = self.o_proj(attn_output) return output @@ -172,6 +184,10 @@ class LlamaDecoderLayer(nn.Module): max_position_embeddings = getattr(config, "max_position_embeddings", 8192) sliding_window = getattr(config, "sliding_window", None) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) self.self_attn = LlamaAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -181,7 +197,7 @@ class LlamaDecoderLayer(nn.Module): rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, linear_method=linear_method, - bias=getattr(config, "bias", False), + bias=attention_bias, sliding_window=sliding_window, ) self.mlp = LlamaMLP( @@ -402,3 +418,27 @@ class LlamaForCausalLM(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, tp_rank, tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type): + layer_self_attn = self.model.layers[layer_idx].self_attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py new file mode 100644 index 0000000000000..99d1b4eb97bb8 --- /dev/null +++ b/vllm/model_executor/models/minicpm.py @@ -0,0 +1,537 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only MiniCPM model compatible with HuggingFace weights.""" +import math +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import LoRAConfig +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + + +class MiniCPMMoE(nn.Module): + """A tensor-parallel MoE implementation that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + ): + super().__init__() + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.num_total_experts = num_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size // self.tp_size + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + self.gate = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + linear_method=None) + + self.ws = nn.Parameter( + torch.empty(self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + device="cuda", + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_total_experts, + self.hidden_size, + self.intermediate_size, + device="cuda", + dtype=self.params_dtype)) + + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = fused_moe(hidden_states, + self.ws, + self.w2s, + router_logits, + self.top_k, + renormalize=True, + inplace=True) + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_size) + + +class MiniCPMMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MiniCPMAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + # set rope as fp32 instead of bf16 + self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache( + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + orig_dtype = q.dtype + q, k = q.float(), k.float() + q, k = self.rotary_emb(positions, q, k) + q, k = q.to(orig_dtype), k.to(orig_dtype) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class MiniCPMDecoderLayer(nn.Module): + + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = MiniCPMAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + linear_method=linear_method, + ) + self.num_experts = getattr(self.config, "num_experts", 0) + if self.num_experts == 0: + self.mlp = MiniCPMMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + ) + else: + self.mlp = MiniCPMMoE(num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states * \ + (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * \ + (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) + + return hidden_states, None + + +class MiniCPMModel(nn.Module): + + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.layers = nn.ModuleList([ + MiniCPMDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + embedding = self.embed_tokens(input_ids) + return embedding * self.config.scale_emb + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i], + attn_metadata, + residual, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class MiniCPMForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.num_experts = getattr(self.config, "num_experts", 0) + self.linear_method = linear_method + self.model = MiniCPMModel(config, + linear_method, + lora_config=lora_config) + unpadded_vocab_size = config.vocab_size + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size + if not self.config.tie_word_embeddings: + self.lm_head = ParallelLMHead( + unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.scale_width = self.config.hidden_size / self.config.dim_model_base + + self.logits_processor = LogitsProcessor(unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + hidden_states = hidden_states / self.scale_width + if self.config.tie_word_embeddings: + lm_head_weight = self.model.embed_tokens.weight + else: + lm_head_weight = self.lm_head.weight + logits = self.logits_processor(lm_head_weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + expert_params_mapping = [ + # (param_name, weight_name, expert_id) + ("ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.num_experts) + for weight_name in ["w1", "w2", "w3"] + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 459f11d1d35a7..611a48a9aad2b 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -39,14 +39,15 @@ from typing import List, Optional, Tuple import torch -import torch.nn.functional as F # this model must need this dependency from hf_olmo import OLMoConfig from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -62,17 +63,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, from vllm.sequence import SamplerOutput -class SwiGLU(nn.Module): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x, gate = x.chunk(2, dim=-1) - return F.silu(gate) * x - - @property - def output_multiplier(self) -> float: - return 0.5 - - class OlmoAttention(nn.Module): """ This is the attention block where the output is computed as @@ -174,17 +164,16 @@ class OlmoMLP(nn.Module): bias=False) # Feed-forward input projection. - self.ff_proj = ColumnParallelLinear( + self.ff_proj = MergedColumnParallelLinear( config.d_model, - self.hidden_size, + [self.hidden_size // 2] * 2, bias=config.include_bias, linear_method=linear_method, ) # Activation function. - # self.act = SiluAndMul() - # self.act.output_multiplier = 0.5 - self.act = SwiGLU() + self.act = SiluAndMul() + self.act.output_multiplier = 0.5 assert (self.act.output_multiplier * self.hidden_size) % 1 == 0 # Feed-forward output projection. @@ -374,8 +363,12 @@ class OLMoForCausalLM(nn.Module): if ".att" in name: name = name.replace(".att", ".attn.att") # mlp - if ".ff" in name and "transformer.ff_out" not in name: - name = name.replace(".ff", ".mlp.ff") + if ".ff_proj" in name: + name = name.replace(".ff_proj", ".mlp.ff_proj") + # Reverse the weight for the MergeColumnParallelLinear + loaded_weight = torch.concat(loaded_weight.chunk(2)[::-1]) + if ".ff_out" in name and "transformer.ff_out" not in name: + name = name.replace(".ff_out", ".mlp.ff_out") # there is no bias in olmo param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py index bcda5ebf8548b..3bbfa1bd5443a 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/model_executor/parallel_utils/parallel_state.py @@ -4,6 +4,7 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Tensor and pipeline parallel groups.""" import contextlib +from typing import Optional import torch @@ -14,14 +15,59 @@ _TENSOR_MODEL_PARALLEL_GROUP = None # Pipeline model parallel group that the current rank belongs to. _PIPELINE_MODEL_PARALLEL_GROUP = None +# when people blindly call `torch.distributed.all_reduce` etc, +# it will use this group. It is initialized with the `backend` +# parameter of `init_distributed_environment` below. +# Essentially, this is `torch.distributed.group.WORLD`. +# We leave a line here to note that this is device-specific. +# Note that this variable is not safe to use, because when users +# call `init_distributed_environment` first, and then destroy +# the process group themselves, this variable will keep a reference to the +# destroyed process group, which is not useful. +_DEVICE_WORLD_GROUP = None + +# duing `init_distributed_environment`, we will also initialize a +# group with `gloo` backend, to allow direct coordination between +# processes through the CPU. +_CPU_WORLD_GROUP = None + +# In summary, after calling `init_distributed_environment`, we will +# always have two groups: one for device-specific (and is the default) +# and one for CPU. All processes will be part of both groups. + # A list of global ranks for each pipeline group to ease calculation of the # source rank when broadcasting from the first or last pipeline stage. _PIPELINE_GLOBAL_RANKS = None +def init_distributed_environment( + world_size: int, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, + backend: str = "nccl", +): + if not torch.distributed.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing " + "distributed environment") + # this backend is used for WORLD + torch.distributed.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank) + global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP + _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD + ranks = list(range(torch.distributed.get_world_size())) + _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks, + backend="gloo") + + def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, ) -> None: """ Initialize model parallel groups. @@ -48,6 +94,8 @@ def initialize_model_parallel( # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() + # get the backend of _DEVICE_WORLD_GROUP + backend = backend or torch.distributed.get_backend() if (world_size != tensor_model_parallel_size * pipeline_model_parallel_size): @@ -69,7 +117,7 @@ def initialize_model_parallel( for i in range(num_tensor_model_parallel_groups): ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) - group = torch.distributed.new_group(ranks) + group = torch.distributed.new_group(ranks, backend=backend) if rank in ranks: _TENSOR_MODEL_PARALLEL_GROUP = group @@ -80,7 +128,7 @@ def initialize_model_parallel( "pipeline model parallel group is already initialized") for i in range(num_pipeline_model_parallel_groups): ranks = range(i, world_size, num_pipeline_model_parallel_groups) - group = torch.distributed.new_group(ranks) + group = torch.distributed.new_group(ranks, backend=backend) if rank in ranks: _PIPELINE_MODEL_PARALLEL_GROUP = group _PIPELINE_GLOBAL_RANKS = ranks @@ -89,14 +137,17 @@ def initialize_model_parallel( def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, + backend: Optional[str] = None, ) -> None: """Helper to initialize model parallel groups if they are not initialized, or ensure tensor-parallel and pipeline-parallel sizes are equal to expected values if the model parallel groups are initialized. """ + # get the backend of _DEVICE_WORLD_GROUP + backend = backend or torch.distributed.get_backend() if not model_parallel_is_initialized(): initialize_model_parallel(tensor_model_parallel_size, - pipeline_model_parallel_size) + pipeline_model_parallel_size, backend) return assert ( @@ -117,6 +168,12 @@ def model_parallel_is_initialized(): and _PIPELINE_MODEL_PARALLEL_GROUP is not None) +def get_cpu_world_group(): + """Get the CPU world group.""" + assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized") + return _CPU_WORLD_GROUP + + def get_tensor_model_parallel_group(): """Get the tensor model parallel group the caller rank belongs to.""" assert _TENSOR_MODEL_PARALLEL_GROUP is not None, ( diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py index 2aed70f05e067..0a8bb860efa1c 100644 --- a/vllm/model_executor/parallel_utils/pynccl.py +++ b/vllm/model_executor/parallel_utils/pynccl.py @@ -21,6 +21,7 @@ import ctypes import datetime +import glob import os # ===================== import region ===================== @@ -34,18 +35,27 @@ logger = init_logger(__name__) so_file = os.environ.get("VLLM_NCCL_SO_PATH", "") +# check if we have vllm-managed nccl +vllm_nccl_path = None +if torch.version.cuda is not None: + cuda_major = torch.version.cuda.split(".")[0] + path = os.path.expanduser( + f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*") + files = glob.glob(path) + vllm_nccl_path = files[0] if files else None + # manually load the nccl library if so_file: logger.info( f"Loading nccl from environment variable VLLM_NCCL_SO_PATH={so_file}") else: if torch.version.cuda is not None: - so_file = "libnccl.so.2" + so_file = vllm_nccl_path or "libnccl.so.2" elif torch.version.hip is not None: so_file = "librccl.so.1" else: raise ValueError("NCCL only supports CUDA and ROCm backends.") - logger.debug(f"Loading nccl from library {so_file}") + logger.info(f"Loading nccl from library {so_file}") try: nccl = ctypes.CDLL(so_file) @@ -226,22 +236,25 @@ class NCCLCommunicator: if local_rank == -1: local_rank = self.rank self.local_rank = local_rank - torch.cuda.set_device(local_rank) - if rank == 0: + # don't use these args, as they can be -1 + # use `self.rank`, `self.local_rank` and `self.world_size` instead + del world_size, rank, local_rank + torch.cuda.set_device(self.local_rank) + if self.rank == 0: self.unique_id = ncclGetUniqueId() else: self.unique_id = NcclUniqueId() - tensor = torch.ByteTensor(list( - self.unique_id.internal)).cuda(local_rank) + tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda( + self.local_rank) dist.broadcast(tensor, src=0) byte_list = tensor.cpu().tolist() for i, byte in enumerate(byte_list): self.unique_id.internal[i] = byte self.comm = ctypes.c_void_p() - result = _c_ncclCommInitRank(ctypes.byref(self.comm), world_size, - self.unique_id, rank) + result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, + self.unique_id, self.rank) assert result == 0 - self.stream = torch.cuda.Stream(device=f"cuda:{local_rank}") + self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}") def all_reduce(self, tensor: torch.Tensor, @@ -261,4 +274,6 @@ class NCCLCommunicator: # `dist` module might have been already destroyed if hasattr(dist, 'destroy_process_group'): dist.destroy_process_group() - _c_ncclCommDestroy(self.comm) + # function might have been already destroyed + if _c_ncclCommDestroy is not None: + _c_ncclCommDestroy(self.comm) diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 9181f298871db..0961478930d74 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -5,9 +5,10 @@ import hashlib import json import os from collections import defaultdict -from typing import Any, Iterator, List, Optional, Tuple +from typing import Any, Iterable, Iterator, List, Optional, Tuple import filelock +import huggingface_hub.constants import numpy as np import torch from huggingface_hub import HfFileSystem, snapshot_download @@ -18,6 +19,7 @@ from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QuantizationConfig, get_quantization_config) +from vllm.model_executor.layers.quantization.schema import QuantParamSchema logger = init_logger(__name__) @@ -29,6 +31,21 @@ temp_dir = os.environ.get('TMPDIR') or os.environ.get( 'TEMP') or os.environ.get('TMP') or "/tmp/" +def enable_hf_transfer(): + """automatically activates hf_transfer + """ + if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: + try: + # enable hf hub transfer if available + import hf_transfer # type: ignore # noqa + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True + except ImportError: + pass + + +enable_hf_transfer() + + class Disabledtqdm(tqdm): def __init__(self, *args, **kwargs): @@ -275,6 +292,46 @@ def hf_model_weights_iterator( torch.cuda.empty_cache() +def kv_cache_scales_loader( + filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int, + model_type: Optional[str]) -> Iterable[Tuple[int, float]]: + """ + A simple utility to read in KV cache scaling factors that have been + previously serialized to disk. Used by the model to populate the appropriate + KV cache scaling factors. The serialization should represent a dictionary + whose keys are the TP ranks and values are another dictionary mapping layers + to their KV cache scaling factors. + Keep this function in sync with the output of examples/fp8/extract_scales.py + """ + try: + with open(filename) as f: + context = { + "model_type": model_type, + "num_hidden_layers": num_hidden_layers, + "tp_rank": tp_rank, + "tp_size": tp_size, + } + schema_dct = json.load(f) + schema = QuantParamSchema.model_validate(schema_dct, + context=context) + layer_scales_map = schema.kv_cache.scaling_factor[tp_rank] + return layer_scales_map.items() + + except FileNotFoundError: + logger.error(f"File or directory '{filename}' not found.") + except json.JSONDecodeError: + logger.error(f"Error decoding JSON in file '{filename}'.") + except Exception as e: + logger.error(f"An error occurred while reading '{filename}': {e}") + # This section is reached if and only if any of the excepts are hit + # Return an empty iterable (list) => no KV cache scales are loaded + # which ultimately defaults to 1.0 scales + logger.warning("Defaulting to KV cache scaling factors = 1.0 " + f"for all layers in TP rank {tp_rank} " + "as an error occurred during loading.") + return [] + + def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: """convert PySafeSlice object from safetensors to torch.Tensor diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 6f81ee31f84dd..4fdc3c6dedaef 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -5,6 +5,7 @@ from functools import cached_property from typing import Callable, List, Optional, Union import torch +from pydantic import conint _SAMPLING_EPS = 1e-5 @@ -88,11 +89,15 @@ class SamplingParams: log probability of the sampled token, so there may be up to `logprobs+1` elements in the response. prompt_logprobs: Number of log probabilities to return per prompt token. + detokenize: Whether to detokenize the output. Defaults to True. skip_special_tokens: Whether to skip special tokens in the output. spaces_between_special_tokens: Whether to add spaces between special tokens in the output. Defaults to True. logits_processors: List of functions that modify logits based on previously generated tokens. + truncate_prompt_tokens: If set to an integer k, will use only the last k + tokens from the prompt (i.e., left truncation). Defaults to None + (i.e., no truncation). """ def __init__( @@ -118,9 +123,11 @@ class SamplingParams: min_tokens: int = 0, logprobs: Optional[int] = None, prompt_logprobs: Optional[int] = None, + detokenize: bool = True, skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, logits_processors: Optional[List[LogitsProcessor]] = None, + truncate_prompt_tokens: Optional[conint(ge=1)] = None, ) -> None: self.n = n self.best_of = best_of if best_of is not None else n @@ -150,10 +157,15 @@ class SamplingParams: self.min_tokens = min_tokens self.logprobs = logprobs self.prompt_logprobs = prompt_logprobs + # NOTE: This parameter is only exposed at the engine level for now. + # It is not exposed in the OpenAI API server, as the OpenAI API does + # not support returning only a list of token IDs. + self.detokenize = detokenize self.skip_special_tokens = skip_special_tokens self.spaces_between_special_tokens = spaces_between_special_tokens self.logits_processors = logits_processors self.include_stop_str_in_output = include_stop_str_in_output + self.truncate_prompt_tokens = truncate_prompt_tokens self._verify_args() if self.use_beam_search: self._verify_beam_search() @@ -210,6 +222,14 @@ class SamplingParams: if self.prompt_logprobs is not None and self.prompt_logprobs < 0: raise ValueError(f"prompt_logprobs must be non-negative, got " f"{self.prompt_logprobs}.") + if (self.truncate_prompt_tokens is not None + and self.truncate_prompt_tokens < 1): + raise ValueError(f"truncate_prompt_tokens must be >= 1, " + f"got {self.truncate_prompt_tokens}") + if self.stop and not self.detokenize: + raise ValueError( + "stop strings are only supported when detokenize is True. " + "Set detokenize=True to use stop.") def _verify_beam_search(self) -> None: if self.best_of == 1: @@ -290,4 +310,5 @@ class SamplingParams: f"prompt_logprobs={self.prompt_logprobs}, " f"skip_special_tokens={self.skip_special_tokens}, " "spaces_between_special_tokens=" - f"{self.spaces_between_special_tokens})") + f"{self.spaces_between_special_tokens}, " + f"truncate_prompt_tokens={self.truncate_prompt_tokens})") diff --git a/vllm/sequence.py b/vllm/sequence.py index a40f38f76d1c4..576bbe8c4f6c4 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -69,6 +69,11 @@ class SequenceStatus(enum.Enum): return finish_reason +class SequenceStage(enum.Enum): + PREFILL = enum.auto() + DECODE = enum.auto() + + @dataclass class RequestMetrics: """Metrics associated with a request. @@ -115,6 +120,7 @@ class SequenceData: self.cumulative_logprob = 0.0 # The number of tokens that are computed (that run against the model). self._num_computed_tokens = 0 + self._stage: SequenceStage = SequenceStage.PREFILL def append_token_id(self, token_id: int, logprob: float) -> None: self.output_token_ids.append(token_id) @@ -136,16 +142,22 @@ class SequenceData: """Return the number of prefill tokens that are already computed.""" return self._num_computed_tokens - def update_num_computed_tokens(self, num_new_computed_tokens: int) -> int: + def update_num_computed_tokens(self, num_new_computed_tokens: int): """Update number of tokens computed so far.""" self._num_computed_tokens += num_new_computed_tokens + assert self._num_computed_tokens <= self.get_len(), ( + self._num_computed_tokens, self.get_len()) + # If all tokens are computed, it means it is in decoding phase. + if self.get_num_uncomputed_tokens() == 0: + self._stage = SequenceStage.DECODE - def reset_num_computed_tokens(self) -> None: + def reset_state_for_recompute(self) -> None: """Reset the number of computed tokens from this sequence. It is supposed to be called when a sequence needs to be started from the beginning again (e.g., sequence is preempted). """ self._num_computed_tokens = 0 + self._stage = SequenceStage.PREFILL def get_num_uncomputed_tokens(self) -> int: """Return the number of prefil tokens that are not computed.""" @@ -165,6 +177,10 @@ class SequenceData: def get_output_token_ids(self) -> int: return self.output_token_ids + @property + def stage(self) -> SequenceStage: + return self._stage + def __repr__(self) -> str: return (f"SequenceData(" f"prompt_token_ids={self.prompt_token_ids}, " @@ -234,7 +250,7 @@ class Sequence: def reset_state_for_recompute(self): """Reset the sequence states for recomputation.""" - self.data.reset_num_computed_tokens() + self.data.reset_state_for_recompute() def _append_logical_block(self) -> None: block = LogicalTokenBlock( @@ -320,6 +336,23 @@ class Sequence: new_seq.seq_id = new_seq_id return new_seq + def get_num_new_tokens(self) -> int: + """Get the number of new tokens to be computed. + + Args: + remainig_token_budget: The remaining token budgets. + Returns: + The new number of tokens to be computed. I.e., 1 for decode, prompt + size for prefill. If there's not enough remainig_token_budget, it + can return the chunked number of new tokens. + """ + if self.data.stage == SequenceStage.DECODE: + return 1 + return self.data.get_num_uncomputed_tokens() + + def is_prefill(self) -> bool: + return self.data.stage == SequenceStage.PREFILL + def __repr__(self) -> str: return (f"Sequence(seq_id={self.seq_id}, " f"status={self.status.name}, " @@ -461,14 +494,14 @@ class SequenceGroup: def update_num_computed_tokens(self, num_new_computed_tokens: int): """Update number of tokens computed so far.""" for seq in self.seqs_dict.values(): - seq.data.update_num_computed_tokens(num_new_computed_tokens) + if not seq.is_finished(): + seq.data.update_num_computed_tokens(num_new_computed_tokens) def get_num_uncomputed_tokens(self) -> int: - # All sequences in the group should have the same prompt, so the - # number of unfinished prefill tokens are the same across all - # sequences. - return list( - self.seqs_dict.values())[0].data.get_num_uncomputed_tokens() + num_uncomputed_tokens = 0 + for seq in self.get_seqs(): + num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() + return num_uncomputed_tokens def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: return len(self.get_seqs(status)) @@ -497,6 +530,10 @@ class SequenceGroup: def is_finished(self) -> bool: return all(seq.is_finished() for seq in self.get_seqs()) + def is_prefill(self) -> bool: + # Every sequences should be in the same stage. + return self.get_seqs()[0].is_prefill() + def __repr__(self) -> str: return (f"SequenceGroup(request_id={self.request_id}, " f"sampling_params={self.sampling_params}, " @@ -513,8 +550,8 @@ class SequenceGroupMetadata: sampling_params: The sampling parameters used to generate the outputs. block_tables: The block tables. (Seq id -> list of physical block numbers) - token_chunk_size: The number of tokens to be processed. None if - chunking is not required. + token_chunk_size: The number of tokens to be processed (per sequence). + None if chunking is not required. state: Internal state tied to this sequence group. lora_request: LoRA request. multi_modal_data: Multi modal data. diff --git a/vllm/spec_decode/__init__.py b/vllm/spec_decode/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 59f9d5b5107f3..885bf537568e3 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -3,7 +3,6 @@ from typing import Dict, List, Optional, Tuple import torch -from vllm.config import CacheConfig from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, SequenceGroupOutput, SequenceOutput) @@ -15,9 +14,10 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, split_batch_by_proposal_len) from vllm.worker.worker import Worker +from vllm.worker.worker_base import LoraNotSupportedWorkerBase -class SpecDecodeWorker: +class SpecDecodeWorker(LoraNotSupportedWorkerBase): """Worker which implements speculative decoding. Speculative decoding reduces decoding per-token latency by using a proposal @@ -94,10 +94,7 @@ class SpecDecodeWorker: device=self.device, vocab_size=self._vocab_size) - def profile_num_available_blocks(self, block_size: int, - gpu_memory_utilization: float, - cpu_swap_space: int, - cache_dtype: str) -> Tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of cache blocks to use. This is done by profiling the scorer model (which is typically the @@ -106,27 +103,26 @@ class SpecDecodeWorker: such that the number of blocks is equal in both KV caches. """ num_gpu_blocks, num_cpu_blocks = ( - self.scorer_worker.profile_num_available_blocks( - block_size, gpu_memory_utilization, cpu_swap_space, - cache_dtype)) + self.scorer_worker.determine_num_available_blocks()) scorer_cache_block_size_bytes = ( - self.scorer_worker.get_cache_block_size_bytes( - block_size, cache_dtype)) + self.scorer_worker.get_cache_block_size_bytes()) proposer_cache_block_size_bytes = ( - self.proposer_worker.get_cache_block_size_bytes( - block_size, cache_dtype)) + self.proposer_worker.get_cache_block_size_bytes()) new_num_gpu_blocks = split_num_cache_blocks_evenly( scorer_cache_block_size_bytes, proposer_cache_block_size_bytes, num_gpu_blocks) return new_num_gpu_blocks, num_cpu_blocks - def init_cache_engine(self, cache_config: CacheConfig): + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: """Initialize the cache engine of the scorer and proposer workers. """ - self.scorer_worker.init_cache_engine(cache_config) - self.proposer_worker.init_cache_engine(cache_config) + self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) @torch.inference_mode() def execute_model( @@ -351,6 +347,16 @@ class SpecDecodeWorker: def device(self): return self.scorer_worker.device + def get_cache_block_size_bytes(self): + """Return the size of a cache block in bytes. + + This function is only used to compose workers within a SpecDecodeWorker. + We leave composing a SpecDecodeWorker within a SpecDecodeWorker + undefined for now, although it could be implemented in the future. + See https://arxiv.org/abs/2308.04623. + """ + raise NotImplementedError + def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int, proposer_cache_block_size_bytes: int, diff --git a/vllm/test_utils.py b/vllm/test_utils.py index 94e962e12e87b..bc220d3b8a430 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -1,8 +1,8 @@ import ray -from vllm.config import ParallelConfig +from vllm.model_executor.parallel_utils.parallel_state import ( + ensure_model_parallel_initialized, init_distributed_environment) from vllm.utils import get_open_port -from vllm.worker.worker import init_distributed_environment def init_test_distributed_environment( @@ -12,15 +12,14 @@ def init_test_distributed_environment( distributed_init_port: str, local_rank: int = -1, ) -> None: - parallel_config = ParallelConfig(pipeline_parallel_size, - tensor_parallel_size, - worker_use_ray=True) distributed_init_method = f"tcp://localhost:{distributed_init_port}" init_distributed_environment( - parallel_config, - rank, + world_size=pipeline_parallel_size * tensor_parallel_size, + rank=rank, distributed_init_method=distributed_init_method, local_rank=local_rank) + ensure_model_parallel_initialized(tensor_parallel_size, + pipeline_parallel_size) def multi_process_tensor_parallel( diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 419687e23b718..486c1938e1e10 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -1,10 +1,8 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple, Union -from transformers import PreTrainedTokenizer +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup -from vllm.transformers_utils.tokenizer import (convert_prompt_ids_to_tokens, - detokenize_incrementally) from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( BaseTokenizerGroup) @@ -148,10 +146,160 @@ class Detokenizer: ) sample_logprob.decoded_token = new_text - if seq.tokens is None: - seq.tokens = new_tokens - else: - seq.tokens.extend(new_tokens) + seq.tokens.extend(new_tokens) seq.prefix_offset = prefix_offset seq.read_offset = read_offset seq.output_text += new_decoded_token_text + + +def _convert_tokens_to_string_with_added_encoders( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + output_tokens: List[str], + skip_special_tokens: bool, + spaces_between_special_tokens: bool, +) -> str: + # Adapted from + # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921 + # NOTE(woosuk): The following code is slow because it runs a for loop over + # the output_tokens. In Python, running a for loop over a list can be slow + # even when the loop body is very simple. + sub_texts = [] + current_sub_text = [] + all_special_tokens = set(tokenizer.all_special_tokens) + for token in output_tokens: + if skip_special_tokens and token in all_special_tokens: + continue + if token in tokenizer.get_added_vocab(): + if current_sub_text: + sub_text = tokenizer.convert_tokens_to_string(current_sub_text) + sub_texts.append(sub_text) + current_sub_text = [] + sub_texts.append(token) + else: + current_sub_text.append(token) + if current_sub_text: + sub_text = tokenizer.convert_tokens_to_string(current_sub_text) + sub_texts.append(sub_text) + if spaces_between_special_tokens: + return " ".join(sub_texts) + else: + return "".join(sub_texts) + + +# 5 is an arbitrary value that should work for all +# tokenizers (bigger = more conservative). +INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5 + + +def convert_prompt_ids_to_tokens( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + prompt_ids: List[int], + skip_special_tokens: bool = False, +) -> Tuple[List[str], int, int]: + """Converts the prompt ids to tokens and returns the tokens and offsets + for incremental detokenization. + + Note that not all tokens are converted to strings. Only the tokens that + are necessary for incremental detokenization are converted to strings. + """ + # We do not need to convert the whole prompt to tokens. + # Offset a little more in case we have special tokens. + new_tokens = tokenizer.convert_ids_to_tokens( + prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:], + skip_special_tokens=skip_special_tokens) + read_offset = len(new_tokens) + prefix_offset = max( + read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) + return new_tokens, prefix_offset, read_offset + + +# Based on +# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15 +# under Apache 2.0 license +def detokenize_incrementally( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + all_input_ids: List[int], + prev_tokens: Optional[List[str]], + prefix_offset: int, + read_offset: int, + skip_special_tokens: bool = False, + spaces_between_special_tokens: bool = True, +) -> Tuple[List[str], str, int, int]: + """Detokenizes the input ids incrementally and returns the new tokens + and the new text. + + If `prev_tokens` is None, this function will convert the input ids to + tokens and return the tokens and the new text. Otherwise, it will return the + new tokens and the new text. + + This function will also return the new prefix offset and the new read + offset to be used in the next iteration. + + The offsets are necessary to defeat cleanup algorithms in the decode which + decide to add a space or not depending on the surrounding ids. + + Args: + tokenizer: The tokenizer to use. + all_input_ids: The input ids. The last id is the new token id. + prev_tokens: The previous tokens. If None, this function will convert + the input ids to tokens and return the tokens and the new text. + prefix_offset: The prefix offset. + read_offset: The read offset. + skip_special_tokens: Whether to skip special tokens. + spaces_between_special_tokens: Whether to add spaces between special + tokens. + """ + new_token_id = all_input_ids[-1] + # This is the first iteration for this sequence + is_first_iter = prev_tokens is None + if is_first_iter: + (prev_tokens, prefix_offset, + read_offset) = convert_prompt_ids_to_tokens( + tokenizer, + all_input_ids[:-1], + skip_special_tokens=skip_special_tokens) + + # If the new token id is out of bounds, return an empty string. + if new_token_id >= len(tokenizer): + new_tokens = [""] + else: + # Put new_token_id in a list so skip_special_tokens is respected + new_tokens = tokenizer.convert_ids_to_tokens( + [new_token_id], skip_special_tokens=skip_special_tokens) + output_tokens = prev_tokens + new_tokens + + # If this is the first iteration, return all tokens. + if is_first_iter: + new_tokens = output_tokens + + # The prefix text is necessary only to defeat cleanup algorithms in + # the decode which decide to add a space or not depending on the + # surrounding ids. + if tokenizer.is_fast or not tokenizer.get_added_vocab(): + prefix_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:read_offset]) + new_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:]) + else: + prefix_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:read_offset], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + new_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + + if len(new_text) <= len(prefix_text) or new_text.endswith("�"): + # utf-8 char at the end means it's a potential unfinished byte sequence + # from byte fallback tokenization. + # If it's in the middle, it's probably a real invalid id generated + # by the model + return new_tokens, "", prefix_offset, read_offset + + new_text = new_text[len(prefix_text):] + return new_tokens, new_text, read_offset, len(output_tokens) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 3bda3f419d8a2..e216a99af91f9 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import Optional, Union from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) @@ -126,157 +126,3 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args, get_lora_tokenizer_async = make_async(get_lora_tokenizer) - - -def _convert_tokens_to_string_with_added_encoders( - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - output_tokens: List[str], - skip_special_tokens: bool, - spaces_between_special_tokens: bool, -) -> str: - # Adapted from - # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921 - # NOTE(woosuk): The following code is slow because it runs a for loop over - # the output_tokens. In Python, running a for loop over a list can be slow - # even when the loop body is very simple. - sub_texts = [] - current_sub_text = [] - all_special_tokens = set(tokenizer.all_special_tokens) - for token in output_tokens: - if skip_special_tokens and token in all_special_tokens: - continue - if token in tokenizer.get_added_vocab(): - if current_sub_text: - sub_text = tokenizer.convert_tokens_to_string(current_sub_text) - sub_texts.append(sub_text) - current_sub_text = [] - sub_texts.append(token) - else: - current_sub_text.append(token) - if current_sub_text: - sub_text = tokenizer.convert_tokens_to_string(current_sub_text) - sub_texts.append(sub_text) - if spaces_between_special_tokens: - return " ".join(sub_texts) - else: - return "".join(sub_texts) - - -# 5 is an arbitrary value that should work for all -# tokenizers (bigger = more conservative). -INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5 - - -def convert_prompt_ids_to_tokens( - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - prompt_ids: List[int], - skip_special_tokens: bool = False, -) -> Tuple[List[str], int, int]: - """Converts the prompt ids to tokens and returns the tokens and offsets - for incremental detokenization. - - Note that not all tokens are converted to strings. Only the tokens that - are necessary for incremental detokenization are converted to strings. - """ - # Offset a little more in case we have special tokens. - prefix_offset = max( - len(prompt_ids) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2, 0) - # We do not need to convert the whole prompt to tokens. - new_tokens = tokenizer.convert_ids_to_tokens( - prompt_ids[prefix_offset:], skip_special_tokens=skip_special_tokens) - prefix_offset = max( - len(new_tokens) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) - read_offset = len(new_tokens) - return new_tokens, prefix_offset, read_offset - - -# Based on -# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15 -# under Apache 2.0 license -def detokenize_incrementally( - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - all_input_ids: List[int], - prev_tokens: Optional[List[str]], - prefix_offset: int, - read_offset: int, - skip_special_tokens: bool = False, - spaces_between_special_tokens: bool = True, -) -> Tuple[List[str], str, int, int]: - """Detokenizes the input ids incrementally and returns the new tokens - and the new text. - - If `prev_tokens` is None, this function will convert the input ids to - tokens and return the tokens and the new text. Otherwise, it will return the - new tokens and the new text. - - This function will also return the new prefix offset and the new read - offset to be used in the next iteration. - - The offsets are necessary to defeat cleanup algorithms in the decode which - decide to add a space or not depending on the surrounding ids. - - Args: - tokenizer: The tokenizer to use. - all_input_ids: The input ids. The last id is the new token id. - prev_tokens: The previous tokens. If None, this function will convert - the input ids to tokens and return the tokens and the new text. - prefix_offset: The prefix offset. - read_offset: The read offset. - skip_special_tokens: Whether to skip special tokens. - spaces_between_special_tokens: Whether to add spaces between special - tokens. - """ - new_token_id = all_input_ids[-1] - # This is the first iteration for this sequence - is_first_iter = prev_tokens is None - if is_first_iter: - (prev_tokens, prefix_offset, - read_offset) = convert_prompt_ids_to_tokens( - tokenizer, - all_input_ids[:-1], - skip_special_tokens=skip_special_tokens) - - # If the new token id is out of bounds, return an empty string. - if new_token_id >= len(tokenizer): - new_tokens = [""] - else: - # Put new_token_id in a list so skip_special_tokens is respected - new_tokens = tokenizer.convert_ids_to_tokens( - [new_token_id], skip_special_tokens=skip_special_tokens) - output_tokens = prev_tokens + new_tokens - - # If this is the first iteration, return all tokens. - if is_first_iter: - new_tokens = output_tokens - - # The prefix text is necessary only to defeat cleanup algorithms in - # the decode which decide to add a space or not depending on the - # surrounding ids. - if tokenizer.is_fast or not tokenizer.get_added_vocab(): - prefix_text = tokenizer.convert_tokens_to_string( - output_tokens[prefix_offset:read_offset]) - new_text = tokenizer.convert_tokens_to_string( - output_tokens[prefix_offset:]) - else: - prefix_text = _convert_tokens_to_string_with_added_encoders( - tokenizer, - output_tokens[prefix_offset:read_offset], - skip_special_tokens=skip_special_tokens, - spaces_between_special_tokens=spaces_between_special_tokens, - ) - new_text = _convert_tokens_to_string_with_added_encoders( - tokenizer, - output_tokens[prefix_offset:], - skip_special_tokens=skip_special_tokens, - spaces_between_special_tokens=spaces_between_special_tokens, - ) - - if len(new_text) > len(prefix_text) and not new_text.endswith("�"): - # utf-8 char at the end means it's a potential unfinished byte sequence - # from byte fallback tokenization. - # If it's in the middle, it's probably a real invalid id generated - # by the model - new_text = new_text[len(prefix_text):] - return new_tokens, new_text, read_offset, len(output_tokens) - else: - return new_tokens, "", prefix_offset, read_offset diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index 8ea46f7db1681..c00b02fdbbbc0 100644 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -51,6 +51,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): enable_lora=enable_lora, max_num_seqs=max_num_seqs, max_input_length=max_input_length, + **tokenizer_config, ) ray_tokenizer_group_cls = ray.remote( diff --git a/vllm/utils.py b/vllm/utils.py index 442b7945d3209..b4ed42f7148ec 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -7,10 +7,10 @@ import socket import subprocess import uuid import warnings -from collections import OrderedDict +from collections import OrderedDict, defaultdict from functools import lru_cache, partial from platform import uname -from typing import (Any, Awaitable, Callable, Generic, Hashable, List, +from typing import (Any, Awaitable, Callable, Dict, Generic, Hashable, List, Optional, Tuple, TypeVar, Union) import psutil @@ -26,7 +26,7 @@ STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.half, "bfloat16": torch.bfloat16, "float": torch.float, - "fp8_e5m2": torch.uint8, + "fp8": torch.uint8, } @@ -118,6 +118,15 @@ def is_hip() -> bool: return torch.version.hip is not None +@lru_cache(maxsize=None) +def is_cpu() -> bool: + from importlib.metadata import PackageNotFoundError, version + try: + return "cpu" in version("vllm") + except PackageNotFoundError: + return False + + @lru_cache(maxsize=None) def is_neuron() -> bool: try: @@ -263,7 +272,7 @@ def get_nvcc_cuda_version() -> Optional[Version]: return nvcc_cuda_version -def _generate_random_fp8_e5m2( +def _generate_random_fp8( tensor: torch.tensor, low: float, high: float, @@ -279,7 +288,7 @@ def _generate_random_fp8_e5m2( from vllm._C import cache_ops tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) tensor_tmp.uniform_(low, high) - cache_ops.convert_fp8_e5m2(tensor_tmp, tensor) + cache_ops.convert_fp8(tensor_tmp, tensor) del tensor_tmp @@ -308,7 +317,7 @@ def create_kv_caches_with_random( raise ValueError(f"Invalid model dtype: {model_dtype}") elif cache_dtype in ["half", "bfloat16", "float"]: torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] - elif cache_dtype == "fp8_e5m2": + elif cache_dtype == "fp8": torch_dtype = torch.uint8 else: raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") @@ -325,10 +334,10 @@ def create_kv_caches_with_random( key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device) - if cache_dtype == 'fp8_e5m2': - _generate_random_fp8_e5m2(key_cache, -scale, scale) - elif torch_dtype in [torch.half, torch.bfloat16, torch.float]: + if cache_dtype in ["auto", "half", "bfloat16", "float"]: key_cache.uniform_(-scale, scale) + elif cache_dtype == 'fp8': + _generate_random_fp8(key_cache, -scale, scale) else: raise ValueError( f"Does not support key cache of type {cache_dtype}") @@ -340,10 +349,10 @@ def create_kv_caches_with_random( value_cache = torch.empty(size=value_cache_shape, dtype=torch_dtype, device=device) - if cache_dtype == 'fp8_e5m2': - _generate_random_fp8_e5m2(value_cache, -scale, scale) - elif torch_dtype in [torch.half, torch.bfloat16, torch.float]: + if cache_dtype in ["auto", "half", "bfloat16", "float"]: value_cache.uniform_(-scale, scale) + elif cache_dtype == 'fp8': + _generate_random_fp8(value_cache, -scale, scale) else: raise ValueError( f"Does not support value cache of type {cache_dtype}") @@ -368,6 +377,9 @@ def is_pin_memory_available() -> bool: elif is_neuron(): print_warning_once("Pin memory is not supported on Neuron.") return False + elif is_cpu(): + print_warning_once("Pin memory is not supported on CPU.") + return False return True @@ -449,3 +461,20 @@ def maybe_expand_dim(tensor: torch.Tensor, def get_dtype_size(dtype: torch.dtype) -> int: """Get the size of the data type in bytes.""" return torch.tensor([], dtype=dtype).element_size() + + +def merge_dicts(dict1: Dict[Any, List[Any]], + dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]: + """Merge 2 dicts that have key -> List of items. + + When a key conflicts, the values in dict1 is prioritized. + """ + merged_dict = defaultdict(list) + + for key, value in dict1.items(): + merged_dict[key].extend(value) + + for key, value in dict2.items(): + merged_dict[key].extend(value) + + return dict(merged_dict) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index bdc758cb8f03f..cd1bf4d023d9c 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -83,8 +83,7 @@ class CacheEngine: @staticmethod def get_cache_block_size( - block_size: int, - cache_dtype: str, + cache_config: CacheConfig, model_config: ModelConfig, parallel_config: ParallelConfig, ) -> int: @@ -92,12 +91,12 @@ class CacheEngine: num_heads = model_config.get_num_kv_heads(parallel_config) num_layers = model_config.get_num_layers(parallel_config) - key_cache_block = block_size * num_heads * head_size + key_cache_block = cache_config.block_size * num_heads * head_size value_cache_block = key_cache_block total = num_layers * (key_cache_block + value_cache_block) - if cache_dtype == "auto": + if cache_config.cache_dtype == "auto": dtype = model_config.dtype else: - dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] dtype_size = get_dtype_size(dtype) return dtype_size * total diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py new file mode 100644 index 0000000000000..42f0828b826e2 --- /dev/null +++ b/vllm/worker/cpu_worker.py @@ -0,0 +1,317 @@ +"""A CPU worker class.""" +from typing import Dict, List, Optional + +import torch +import torch.distributed + +from vllm.attention import get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.parallel_utils.communication_op import ( + broadcast_tensor_dict) +from vllm.model_executor.parallel_utils.parallel_state import ( + ensure_model_parallel_initialized, init_distributed_environment) +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.worker.model_runner import ModelRunner +from vllm.worker.worker_base import LoraNotSupportedWorkerBase + +logger = init_logger(__name__) + + +class CPUModelRunner(ModelRunner): + + def load_model(self) -> None: + self.model = get_model(self.model_config, + self.device_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) + + +class CPUCacheEngine: + """Manages the KV cache for CPU backend. + + This class is responsible for initializing and managing CPU KV + caches. It also provides methods for performing KV cache operations, such + as copying. + """ + + def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, + parallel_config: ParallelConfig, + device_config: DeviceConfig) -> None: + assert device_config.device_type == "cpu" + self.cache_config = cache_config + self.model_config = model_config + self.parallel_config = parallel_config + + self.head_size = model_config.get_head_size() + self.num_layers = model_config.get_num_layers(parallel_config) + self.num_heads = model_config.get_num_kv_heads(parallel_config) + + self.block_size = cache_config.block_size + # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks + # for CPU backend, because we want to reuse KV cache management + # in the scheduler. + self.num_cpu_blocks = cache_config.num_gpu_blocks + + if cache_config.cache_dtype == "auto": + self.dtype = model_config.dtype + else: + self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + # Get attention backend. + self.attn_backend = get_attn_backend(model_config.dtype) + + # Initialize the cache. + self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) + + def _allocate_kv_cache( + self, + num_blocks: int, + ) -> List[torch.Tensor]: + """Allocates KV cache on CPU.""" + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_heads, self.head_size) + kv_cache: List[torch.Tensor] = [] + for _ in range(self.num_layers): + kv_cache.append( + torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu")) + return kv_cache + + def swap_in(self, src_to_dst: Dict[int, int]) -> None: + raise NotImplementedError("Swap is not supported in CPUCacheEngine.") + + def swap_out(self, src_to_dst: Dict[int, int]) -> None: + raise NotImplementedError("Swap is not supported in CPUCacheEngine.") + + def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: + self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts) + + @staticmethod + def get_cache_block_size( + block_size: int, + cache_dtype: str, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_layers = model_config.get_num_layers(parallel_config) + + key_cache_block = block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_layers * (key_cache_block + value_cache_block) + if cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + dtype_size = torch.tensor([], dtype=dtype).element_size() + return dtype_size * total + + +class CPUWorker(LoraNotSupportedWorkerBase): + """A worker class that executes (a partition of) the model on a CPU socket. + + Each worker is associated with a single CPU socket. The worker is + responsible for maintaining the KV cache and executing the model on the + CPU. In case of distributed inference, each worker is assigned a partition + of the model. + """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." + + self.model_runner = CPUModelRunner(model_config, + parallel_config, + scheduler_config, + device_config, + lora_config=self.lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker) + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine = None + self.cpu_cache = None + + def init_device(self) -> None: + self.init_distributed_environment() + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of blocks available for the KV cache. + + This determines how many KV blocks can fit into the configured CPU + KV cache space. + + Note that since vLLM assumes a block resides on GPU if it can be + modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0. + This allows us to reuse the scheduler of vLLM without generalizing it + to different devices. + """ + # For CPU device, the block number will be calculated based on the + # cpu_kvcache_space. + cache_block_size = self.get_cache_block_size_bytes() + num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes // + cache_block_size) + num_cpu_blocks = max(num_cpu_blocks, 0) + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + num_gpu_blocks = num_cpu_blocks + num_cpu_blocks = 0 + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache. Currently, swappable CPU memory is not + supported. + + Since this worker does not support GPUs, we use the num_gpu_blocks to + determine how many non-swappable CPU blocks to allocate. + """ + assert (num_cpu_blocks == 0 + ), f"{type(self)} does not support swappable cache" + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + num_cpu_blocks = num_gpu_blocks + + self._validate_num_cpu_blocks(num_cpu_blocks) + self.cache_config.num_gpu_blocks = num_cpu_blocks + self.cache_config.num_cpu_blocks = 0 + + # Initialize the cache. + self._init_cache_engine() + + def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: + """Raise errors if the num_cpu_blocks is invalid. + """ + if num_cpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `VLLM_CPU_KVCACHE_SPACE` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_cpu_blocks + if self.model_config.max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when " + "initializing the engine.") + + def _init_cache_engine(self) -> None: + self.cache_engine = CPUCacheEngine(self.cache_config, + self.model_config, + self.parallel_config, + self.device_config) + self.cpu_cache = self.cache_engine.cpu_cache + self.model_runner.block_size = self.cache_engine.block_size + + assert self.cpu_cache is not None + + # Populate the cache to warmup the memory + for layer_cache in self.cpu_cache: + layer_cache.fill_(0) + + def cache_copy( + self, + blocks_to_copy: Dict[int, List[int]], + ) -> None: + if blocks_to_copy: + self.cache_engine.copy(blocks_to_copy) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, + blocks_to_swap_in: Optional[Dict[int, int]] = None, + blocks_to_swap_out: Optional[Dict[int, int]] = None, + blocks_to_copy: Optional[Dict[int, List[int]]] = None, + ) -> Optional[SamplerOutput]: + if self.is_driver_worker: + assert seq_group_metadata_list is not None + num_seq_groups = len(seq_group_metadata_list) + assert blocks_to_swap_in is not None + assert blocks_to_swap_out is not None + assert blocks_to_copy is not None + assert len(blocks_to_swap_in) == 0 + assert len(blocks_to_swap_out) == 0 + data = { + "num_seq_groups": num_seq_groups, + "blocks_to_copy": blocks_to_copy, + } + broadcast_tensor_dict(data, src=0) + else: + data = broadcast_tensor_dict(src=0) + num_seq_groups = data["num_seq_groups"] + blocks_to_copy = data["blocks_to_copy"] + + self.cache_copy(blocks_to_copy) + + # If there is no input, we don't need to execute the model. + if num_seq_groups == 0: + return {} + + output = self.model_runner.execute_model(seq_group_metadata_list, + self.cpu_cache) + return output + + def init_distributed_environment(self) -> None: + """Initialize the distributed environment.""" + + parallel_config = self.parallel_config + rank = self.rank + distributed_init_method = self.distributed_init_method + init_distributed_environment( + world_size=parallel_config.world_size, + rank=rank, + distributed_init_method=distributed_init_method, + backend="gloo", + ) + + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cpu()) + + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + def get_cache_block_size_bytes(self) -> int: + """Return the size in bytes of a single KV cache block. + """ + return CPUCacheEngine.get_cache_block_size( + self.cache_config.block_size, self.cache_config.cache_dtype, + self.model_config, self.parallel_config) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 31fa52476af1d..e7f20475ab1a7 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -23,7 +23,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata) -from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, +from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip, is_pin_memory_available, make_tensor_with_pad, maybe_expand_dim) @@ -120,6 +120,26 @@ class ModelRunner: self.model.embedding_padding_modules) self.model = self.lora_manager.create_lora_manager(self.model) + if self.kv_cache_dtype == "fp8" and is_hip(): + # Currently scaled KV cache is only enabled on ROCm + if self.model_config.quantization_param_path is not None: + if callable(getattr(self.model, "load_kv_cache_scales", None)): + self.model.load_kv_cache_scales( + self.model_config.quantization_param_path) + else: + raise RuntimeError("Using FP8 KV cache and scaling " + "factors provided but model " + f"{self.model.__class__} does not " + "support loading scaling factors.") + else: + logger.warn("Using FP8 KV cache but no scaling factors " + "provided. Defaulting to scaling factors of 1.0. " + "This may lead to less accurate results!") + elif self.model_config.quantization_param_path is not None: + logger.warn("KV cache scaling factors provided, " + "but the KV cache data type is not FP8. " + "KV cache scaling factors will not be used.") + def set_block_size(self, block_size: int) -> None: self.block_size = block_size @@ -202,7 +222,6 @@ class ModelRunner: # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. input_positions.extend(list(range(computed_len, prefill_end))) - lora_id = seq_group_metadata.lora_int_id if lora_id > 0: diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 0ae067aafb29b..6136d50d0c068 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -4,14 +4,15 @@ from typing import List, Optional import torch import torch.distributed -from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig) from vllm.model_executor import set_random_seed from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.neuron_model_runner import NeuronModelRunner +from vllm.worker.worker_base import LoraNotSupportedWorkerBase -class NeuronWorker: +class NeuronWorker(LoraNotSupportedWorkerBase): """A worker class that executes the model on a group of neuron cores. """ @@ -21,11 +22,13 @@ class NeuronWorker: parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + cache_config: CacheConfig, ) -> None: self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config + self.cache_config = cache_config self.model_runner = NeuronModelRunner(model_config, parallel_config, scheduler_config, device_config) @@ -37,6 +40,35 @@ class NeuronWorker: def load_model(self): self.model_runner.load_model() + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available KV blocks. + + Swapping is not yet supported, so always return num_cpu_blocks=0. + + We configure num_gpu_blocks to be equal to max_num_seqs. + """ + # Set the number of GPU blocks to be the same as the maximum number of + # sequences that can be processed in a single batch. This is equivalent + # to schedule without PagedAttention. + num_gpu_blocks = self.scheduler_config.max_num_seqs + + # Swap not yet supported with Neuron backend. + num_cpu_blocks = 0 + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache. + """ + + # Different values are not tested. + assert num_cpu_blocks == 0 + assert num_gpu_blocks == self.scheduler_config.max_num_seqs + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + @torch.inference_mode() def execute_model( self, @@ -50,3 +82,10 @@ class NeuronWorker: output = self.model_runner.execute_model(seq_group_metadata_list) return output + + def get_cache_block_size_bytes(self) -> int: + """Determine the size in bytes of a cache block. + + This is required for speculative decoding; it is not yet implemented. + """ + raise NotImplementedError diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 48facb57de190..19de33089b2db 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -15,13 +15,14 @@ from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar from vllm.model_executor.parallel_utils.parallel_state import ( - ensure_model_parallel_initialized) + ensure_model_parallel_initialized, init_distributed_environment) from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner +from vllm.worker.worker_base import WorkerBase -class Worker: +class Worker(WorkerBase): """A worker class that executes (a partition of) the model on a GPU. Each worker is associated with a single GPU. The worker is responsible for @@ -35,18 +36,19 @@ class Worker: parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + cache_config: CacheConfig, local_rank: int, rank: int, distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, - kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, ) -> None: self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config + self.cache_config = cache_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -66,12 +68,11 @@ class Worker: scheduler_config, device_config, lora_config=self.lora_config, - kv_cache_dtype=kv_cache_dtype, + kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, vision_language_config=vision_language_config) # Uninitialized cache engine. Will be initialized by - # self.init_cache_engine(). - self.cache_config = None + # initialize_cache. self.cache_engine = None self.gpu_cache = None @@ -97,9 +98,9 @@ class Worker: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_distributed_environment(self.parallel_config, self.rank, - self.distributed_init_method, - self.local_rank) + init_worker_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method, + self.local_rank) # Set random seed. set_random_seed(self.model_config.seed) @@ -107,20 +108,17 @@ class Worker: self.model_runner.load_model() @torch.inference_mode() - def profile_num_available_blocks( - self, - block_size: int, - gpu_memory_utilization: float, - cpu_swap_space: int, - cache_dtype: str, - ) -> Tuple[int, int]: - """Profiles the peak memory usage of the model and returns the maximum - number of GPU and CPU cache blocks that can be allocated. + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. - Args: - block_size: The size of the cache block. - gpu_memory_utilization: The fraction of the total GPU memory to use. - cpu_swap_space: The size of the CPU swap space in bytes. + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. """ # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. @@ -141,12 +139,12 @@ class Worker: "Error in memory profiling. This happens when the GPU memory was " "not properly cleaned up before initializing the vLLM instance.") - cache_block_size = self.get_cache_block_size_bytes( - block_size, cache_dtype) + cache_block_size = self.get_cache_block_size_bytes() num_gpu_blocks = int( - (total_gpu_memory * gpu_memory_utilization - peak_memory) // - cache_block_size) - num_cpu_blocks = int(cpu_swap_space // cache_block_size) + (total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) if self.model_runner.lora_manager: @@ -155,14 +153,30 @@ class Worker: torch.cuda.empty_cache() return num_gpu_blocks, num_cpu_blocks - def init_cache_engine(self, cache_config: CacheConfig) -> None: - self.cache_config = cache_config + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Allocate GPU and CPU KV cache with the specified number of blocks. + + This also warms up the model, which may record CUDA graphs. + """ + raise_if_cache_size_invalid(num_gpu_blocks, + self.cache_config.block_size, + self.model_config.max_model_len) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self._init_cache_engine() + self._warm_up_model() + + def _init_cache_engine(self): + assert self.cache_config.num_gpu_blocks is not None self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config) self.gpu_cache = self.cache_engine.gpu_cache self.model_runner.set_block_size(self.cache_engine.block_size) - def warm_up_model(self) -> None: + def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: self.model_runner.capture_model(self.gpu_cache) # Reset the seed to ensure that the random state is not affected by @@ -239,40 +253,23 @@ class Worker: def vocab_size(self) -> int: return self.model_runner.vocab_size - def get_cache_block_size_bytes(self, block_size: int, - cache_dtype: str) -> int: + def get_cache_block_size_bytes(self) -> int: """Get the size of the KV cache block size in bytes. """ - return CacheEngine.get_cache_block_size(block_size, cache_dtype, + return CacheEngine.get_cache_block_size(self.cache_config, self.model_config, self.parallel_config) -def init_distributed_environment( +def init_worker_distributed_environment( parallel_config: ParallelConfig, rank: int, distributed_init_method: Optional[str] = None, local_rank: int = -1, ) -> None: """Initialize the distributed environment.""" - if torch.distributed.is_initialized(): - torch_world_size = torch.distributed.get_world_size() - if torch_world_size != parallel_config.world_size: - raise RuntimeError( - "torch.distributed is already initialized but the torch world " - "size does not match parallel_config.world_size " - f"({torch_world_size} vs. {parallel_config.world_size}).") - elif not distributed_init_method: - raise ValueError( - "distributed_init_method must be set if torch.distributed " - "is not already initialized") - else: - torch.distributed.init_process_group( - backend="nccl", - world_size=parallel_config.world_size, - rank=rank, - init_method=distributed_init_method, - ) + init_distributed_environment(parallel_config.world_size, rank, + distributed_init_method, local_rank) if pynccl_utils.is_initialized(): pynccl_world_size = pynccl_utils.get_world_size() @@ -291,10 +288,6 @@ def init_distributed_environment( init_method=distributed_init_method, ) - # A small all_reduce for warmup. - torch.distributed.all_reduce(torch.zeros(1).cuda()) - if pynccl_utils.is_initialized(): - pynccl_utils.all_reduce(torch.zeros(1).cuda()) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) @@ -302,6 +295,11 @@ def init_distributed_environment( if not parallel_config.disable_custom_all_reduce: init_custom_ar() + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cuda()) + if pynccl_utils.is_initialized(): + pynccl_utils.all_reduce(torch.zeros(1).cuda()) + def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. @@ -315,3 +313,19 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): f"{compute_capability[0]}.{compute_capability[1]}. " "You can use float16 instead by explicitly setting the" "`dtype` flag in CLI, for example: --dtype=half.") + + +def raise_if_cache_size_invalid(num_gpu_blocks, block_size, + max_model_len) -> None: + if num_gpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + max_seq_len = block_size * num_gpu_blocks + if max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py new file mode 100644 index 0000000000000..e3027c406ffeb --- /dev/null +++ b/vllm/worker/worker_base.py @@ -0,0 +1,83 @@ +from abc import ABC, abstractmethod +from typing import Dict, List + +from vllm.lora.request import LoRARequest +from vllm.sequence import SamplerOutput, SequenceGroupMetadata + + +class WorkerBase(ABC): + """Worker interface that allows vLLM to cleanly separate implementations for + different hardware. + """ + + @abstractmethod + def init_device(self) -> None: + """Initialize device state, such as loading the model or other on-device + memory allocations. + """ + raise NotImplementedError + + @abstractmethod + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available blocks for the GPU KV cache and + swappable CPU KV cache. + + The implementation may run profiling or other heuristics to determine + the size of caches. + + Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + are blocks that are "active" on the device and can be appended to. + num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + appended to. + """ + raise NotImplementedError + + @abstractmethod + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache with the given size in blocks. + """ + raise NotImplementedError + + @abstractmethod + def execute_model(self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + """Executes one model step on the given sequences.""" + raise NotImplementedError + + @abstractmethod + def get_cache_block_size_bytes() -> int: + """Return the size of a single cache block, in bytes. Used in + speculative decoding. + """ + raise NotImplementedError + + @abstractmethod + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def list_loras(self) -> List[int]: + raise NotImplementedError + + +class LoraNotSupportedWorkerBase(WorkerBase): + """Partial implementation of WorkerBase that raises exceptions when LoRA + methods are invoked. + """ + + def add_lora(self, lora_request: LoRARequest) -> bool: + raise ValueError(f"{type(self)} does not support LoRA") + + def remove_lora(self, lora_id: int) -> bool: + raise ValueError(f"{type(self)} does not support LoRA") + + def list_loras(self) -> List[int]: + raise ValueError(f"{type(self)} does not support LoRA")