mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 05:27:04 +08:00
Merge remote-tracking branch 'origin/main' into refactor-fp8-linear
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
commit
9784a0c414
@ -132,7 +132,7 @@ steps:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --build-arg VLLM_CPU_AMXBF16=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest"
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
|
||||
env:
|
||||
|
||||
@ -78,17 +78,13 @@ HF_MOUNT="/root/.cache/huggingface"
|
||||
commands=$@
|
||||
echo "Commands:$commands"
|
||||
|
||||
if [[ $commands == *"pytest -v -s basic_correctness/test_basic_correctness.py"* ]]; then
|
||||
commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s basic_correctness/test_basic_correctness.py"}
|
||||
fi
|
||||
commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"pytest -v -s basic_correctness/test_basic_correctness.py"}
|
||||
|
||||
if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then
|
||||
commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"}
|
||||
fi
|
||||
|
||||
if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then
|
||||
commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"}
|
||||
fi
|
||||
commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"pytest -v -s compile/test_basic_correctness.py"}
|
||||
|
||||
if [[ $commands == *"pytest -v -s lora"* ]]; then
|
||||
commands=${commands//"pytest -v -s lora"/"VLLM_ROCM_CUSTOM_PAGED_ATTN=0 pytest -v -s lora"}
|
||||
|
||||
@ -49,6 +49,7 @@ function cpu_tests() {
|
||||
# Run kernel tests
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
|
||||
pytest -x -v -s tests/kernels/test_onednn.py"
|
||||
|
||||
# Run basic model test
|
||||
@ -76,7 +77,7 @@ function cpu_tests() {
|
||||
# Run AWQ test
|
||||
# docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
# set -e
|
||||
# VLLM_USE_V1=0 pytest -x -s -v \
|
||||
# pytest -x -s -v \
|
||||
# tests/quantization/test_ipex_quant.py"
|
||||
|
||||
# Run multi-lora tests
|
||||
@ -116,4 +117,4 @@ function cpu_tests() {
|
||||
|
||||
# All of CPU tests are expected to be finished less than 40 mins.
|
||||
export -f cpu_tests
|
||||
timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
timeout 2.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
|
||||
@ -348,6 +348,7 @@ steps:
|
||||
- pytest -v -s -m 'not cpu_test' v1/metrics
|
||||
- pytest -v -s v1/test_oracle.py
|
||||
- pytest -v -s v1/test_request.py
|
||||
- pytest -v -s v1/test_outputs.py
|
||||
# Integration test for streaming correctness (requires special branch).
|
||||
- pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api
|
||||
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
||||
|
||||
@ -25,6 +25,7 @@
|
||||
# and $$BUILDKITE_PARALLEL_JOB_COUNT environment variables.
|
||||
# working_dir(str): specify the place where the command should execute, default to /vllm-workspace/tests
|
||||
# source_file_dependencies(list): the list of prefixes to opt-in the test for, if empty, the test will always run.
|
||||
# autorun_on_main (bool): default to false, if true, the test will run automatically when commit is pushed to main branch.
|
||||
|
||||
# When adding a test
|
||||
# - If the test belongs to an existing group, add it there
|
||||
@ -56,7 +57,7 @@ steps:
|
||||
- pytest -v -s -m 'not cpu_test' multimodal
|
||||
- pytest -v -s utils_
|
||||
|
||||
- label: Async Engine, Inputs, Utils, Worker Test (CPU) # 4 mins
|
||||
- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 4 mins
|
||||
timeout_in_minutes: 10
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -65,6 +66,7 @@ steps:
|
||||
- tests/multimodal
|
||||
- tests/standalone_tests/lazy_imports.py
|
||||
- tests/transformers_utils
|
||||
- tests/config
|
||||
no_gpu: true
|
||||
commands:
|
||||
- python3 standalone_tests/lazy_imports.py
|
||||
@ -72,6 +74,7 @@ steps:
|
||||
- pytest -v -s test_outputs.py
|
||||
- pytest -v -s -m 'cpu_test' multimodal
|
||||
- pytest -v -s transformers_utils
|
||||
- pytest -v -s config
|
||||
|
||||
- label: Python-only Installation Test # 10min
|
||||
timeout_in_minutes: 20
|
||||
@ -329,6 +332,7 @@ steps:
|
||||
- pytest -v -s -m 'not cpu_test' v1/metrics
|
||||
- pytest -v -s v1/test_oracle.py
|
||||
- pytest -v -s v1/test_request.py
|
||||
- pytest -v -s v1/test_outputs.py
|
||||
# Integration test for streaming correctness (requires special branch).
|
||||
- pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api
|
||||
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
||||
@ -441,6 +445,7 @@ steps:
|
||||
- vllm/
|
||||
- tests/compile
|
||||
commands:
|
||||
- pytest -v -s compile/test_config.py
|
||||
- pytest -v -s compile/test_pass_manager.py
|
||||
- pytest -v -s compile/test_fusion.py
|
||||
- pytest -v -s compile/test_fusion_attn.py
|
||||
@ -450,6 +455,7 @@ steps:
|
||||
- pytest -v -s compile/test_decorator.py
|
||||
- pytest -v -s compile/test_noop_elimination.py
|
||||
- pytest -v -s compile/test_aot_compile.py
|
||||
- pytest -v -s compile/test_qk_norm_rope_fusion.py
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test # 15min
|
||||
timeout_in_minutes: 30
|
||||
@ -604,6 +610,7 @@ steps:
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
autorun_on_main: true
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
||||
|
||||
@ -867,12 +874,12 @@ steps:
|
||||
optional: true
|
||||
commands:
|
||||
- pip install --upgrade git+https://github.com/huggingface/transformers
|
||||
- pytest -v -s tests/models/test_initialization.py
|
||||
- pytest -v -s tests/models/test_initialization.py -k 'not (Gemma3 or ModernBert or Qwen2_5_VL or Qwen2_5vl or Qwen2VL or TransformersMultiModalEmbeddingModel or TransformersMultiModalForSequenceClassification or Ultravox or Phi4Multimodal or LlavaNextVideo or MiniCPMO or Lfm2Moe or PaliGemma or RobertaForSequenceClassification or Ovis2_5 or Fuyu or DeepseekOCR or KimiVL)'
|
||||
- pytest -v -s tests/models/test_transformers.py
|
||||
- pytest -v -s tests/models/multimodal/processing/
|
||||
- pytest -v -s tests/models/multimodal/test_mapping.py
|
||||
# - pytest -v -s tests/models/multimodal/processing/
|
||||
- pytest -v -s tests/models/multimodal/test_mapping.py -k 'not (Gemma3 or Qwen2VL or Qwen2_5_VL)'
|
||||
- python3 examples/offline_inference/basic/chat.py
|
||||
- python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
|
||||
# - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
|
||||
# Whisper needs spawn method to avoid deadlock
|
||||
- VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper
|
||||
|
||||
@ -890,11 +897,16 @@ steps:
|
||||
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
|
||||
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
|
||||
- vllm/v1/attention/backends/flashinfer.py
|
||||
- vllm/v1/attention/backends/mla/cutlass_mla.py
|
||||
- vllm/v1/attention/backends/mla/flashinfer_mla.py
|
||||
- vllm/platforms/cuda.py
|
||||
- vllm/attention/selector.py
|
||||
commands:
|
||||
- nvidia-smi
|
||||
- python3 examples/offline_inference/basic/chat.py
|
||||
# Attention
|
||||
# num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
|
||||
- pytest -v -s tests/kernels/attention/test_attention_selector.py
|
||||
- pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
|
||||
- pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
|
||||
- pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py
|
||||
@ -932,7 +944,7 @@ steps:
|
||||
# this runner has 2 GPUs available even though num_gpus=2 is not set
|
||||
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
||||
# Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
|
||||
# Wrap with quotes to escape yaml
|
||||
# Wrap with quotes to escape yaml
|
||||
- "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and Llama-3.1 and -quant_fp8 and -rms_norm'"
|
||||
|
||||
- label: Blackwell Fusion E2E Tests # 30 min
|
||||
|
||||
10
.github/CODEOWNERS
vendored
10
.github/CODEOWNERS
vendored
@ -61,6 +61,16 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
/vllm/model_executor/models/transformers @hmellor
|
||||
/tests/models/test_transformers.py @hmellor
|
||||
|
||||
# Observability
|
||||
/vllm/config/observability.py @markmc
|
||||
/vllm/v1/metrics @markmc
|
||||
/tests/v1/metrics @markmc
|
||||
/vllm/tracing.py @markmc
|
||||
/tests/v1/tracing/test_tracing.py @markmc
|
||||
/vllm/config/kv_events.py @markmc
|
||||
/vllm/distributed/kv_events.py @markmc
|
||||
/tests/distributed/test_events.py @markmc
|
||||
|
||||
# Docs
|
||||
/docs/mkdocs @hmellor
|
||||
/docs/**/*.yml @hmellor
|
||||
|
||||
17
.github/mergify.yml
vendored
17
.github/mergify.yml
vendored
@ -151,6 +151,23 @@ pull_request_rules:
|
||||
add:
|
||||
- gpt-oss
|
||||
|
||||
- name: label-nvidia
|
||||
description: Automatically apply nvidia label
|
||||
conditions:
|
||||
- label != stale
|
||||
- or:
|
||||
- files~=cuda
|
||||
- files~=cutlass
|
||||
- files~=flashinfer
|
||||
- files~=trtllm
|
||||
- title~=(?i)NVIDIA
|
||||
- title~=(?i)CUDA
|
||||
- title~=(?i)CUTLASS
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- nvidia
|
||||
|
||||
- name: label-rocm
|
||||
description: Automatically apply rocm label
|
||||
conditions:
|
||||
|
||||
@ -39,6 +39,13 @@ set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13")
|
||||
# Supported AMD GPU architectures.
|
||||
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151")
|
||||
|
||||
# ROCm installation prefix. Default to /opt/rocm but allow override via
|
||||
# -DROCM_PATH=/your/rocm/path when invoking cmake.
|
||||
if(NOT DEFINED ROCM_PATH)
|
||||
set(ROCM_PATH "/opt/rocm" CACHE PATH "ROCm installation prefix")
|
||||
else()
|
||||
set(ROCM_PATH ${ROCM_PATH} CACHE PATH "ROCm installation prefix" FORCE)
|
||||
endif()
|
||||
#
|
||||
# Supported/expected torch versions for CUDA/ROCm.
|
||||
#
|
||||
@ -237,10 +244,27 @@ set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_CUMEM_EXT_SRC}"
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
message(STATUS "Enabling cumem allocator extension.")
|
||||
# link against cuda driver library
|
||||
list(APPEND CUMEM_LIBS CUDA::cuda_driver)
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# link against cuda driver library
|
||||
list(APPEND CUMEM_LIBS CUDA::cuda_driver)
|
||||
else()
|
||||
# link against rocm driver library. Prefer an absolute path to
|
||||
# libamdhip64.so inside ${ROCM_PATH}/lib if available, otherwise fall
|
||||
# back to linking by name "amdhip64".
|
||||
find_library(AMDHIP64_LIB
|
||||
NAMES amdhip64 libamdhip64.so
|
||||
PATHS ${ROCM_PATH}/lib
|
||||
NO_DEFAULT_PATH)
|
||||
if(AMDHIP64_LIB)
|
||||
message(STATUS "Found libamdhip64 at ${AMDHIP64_LIB}")
|
||||
list(APPEND CUMEM_LIBS ${AMDHIP64_LIB})
|
||||
else()
|
||||
message(WARNING "libamdhip64 not found in ${ROCM_PATH}/lib; falling back to linking 'amdhip64' by name")
|
||||
list(APPEND CUMEM_LIBS amdhip64)
|
||||
endif()
|
||||
endif()
|
||||
define_extension_target(
|
||||
cumem_allocator
|
||||
DESTINATION vllm
|
||||
@ -265,6 +289,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/pos_encoding_kernels.cu"
|
||||
"csrc/activation_kernels.cu"
|
||||
"csrc/layernorm_kernels.cu"
|
||||
"csrc/fused_qknorm_rope_kernel.cu"
|
||||
"csrc/layernorm_quant_kernels.cu"
|
||||
"csrc/sampler.cu"
|
||||
"csrc/cuda_view.cu"
|
||||
@ -330,7 +355,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
|
||||
# are not supported by Machete yet.
|
||||
# 9.0 for latest bf16 atomicAdd PTX
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.7;9.0+PTX" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
|
||||
if (MARLIN_ARCHS)
|
||||
|
||||
#
|
||||
@ -914,7 +939,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
|
||||
# 9.0 for latest bf16 atomicAdd PTX
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.7;9.0+PTX" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
|
||||
#
|
||||
|
||||
@ -21,6 +21,7 @@ Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundatio
|
||||
|
||||
*Latest News* 🔥
|
||||
|
||||
- [2025/11] We hosted [the first vLLM Europe Meetup in Zurich](https://luma.com/0gls27kb) focused on quantization, distributed inference, and reinforcement learning at scale with speakers from Mistral, IBM, and Red Hat. Please find the meetup slides [here](https://docs.google.com/presentation/d/1UC9PTLCHYXQpOmJDSFg6Sljra3iVXzc09DeEI7dnxMc/edit?usp=sharing) and recording [here](https://www.youtube.com/watch?v=6m6ZE6yVEDI)
|
||||
- [2025/11] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/xSrYXjNgr1HbCP4ExYNG1w) focusing on distributed inference and diverse accelerator support with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1nQJ8ZkLSjKxvu36sSHaceVXtttbLvvu-?usp=drive_link).
|
||||
- [2025/10] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/__xb4OyOsImz-9eAVrdlcg) focused on hands-on vLLM inference optimization! Please find the meetup slides [here](https://drive.google.com/drive/folders/1KqwjsFJLfEsC8wlDugnrR61zsWHt94Q6).
|
||||
- [2025/09] We hosted [vLLM Toronto Meetup](https://luma.com/e80e0ymm) focused on tackling inference at scale and speculative decoding with speakers from NVIDIA and Red Hat! Please find the meetup slides [here](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing).
|
||||
|
||||
@ -1,10 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
|
||||
# Disable DeepGEMM for this benchmark to use CUTLASS
|
||||
os.environ["VLLM_USE_DEEP_GEMM"] = "0"
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
apply_w8a8_block_fp8_linear,
|
||||
W8A8BlockFp8LinearOp,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
@ -39,13 +47,14 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
# Create random FP8 tensors
|
||||
# Create random input tensor (bfloat16, will be quantized by W8A8BlockFp8LinearOp)
|
||||
A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
|
||||
|
||||
# Create quantized weight tensor
|
||||
B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
|
||||
B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
# Create scales
|
||||
# Create weight scales
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
@ -55,19 +64,25 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
|
||||
* factor_for_scale
|
||||
)
|
||||
|
||||
# SM90 CUTLASS requires row-major format for scales
|
||||
if use_cutlass and current_platform.is_device_capability(90):
|
||||
Bs = Bs.T.contiguous()
|
||||
# Create W8A8BlockFp8LinearOp instance
|
||||
weight_group_shape = GroupShape(block_n, block_k)
|
||||
act_quant_group_shape = GroupShape(1, block_k) # Per-token, per-group quantization
|
||||
|
||||
linear_op = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=weight_group_shape,
|
||||
act_quant_group_shape=act_quant_group_shape,
|
||||
cutlass_block_fp8_supported=use_cutlass,
|
||||
use_aiter_and_is_supported=False,
|
||||
)
|
||||
|
||||
def run():
|
||||
if use_cutlass:
|
||||
return apply_w8a8_block_fp8_linear(
|
||||
A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True
|
||||
)
|
||||
else:
|
||||
return apply_w8a8_block_fp8_linear(
|
||||
A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False
|
||||
)
|
||||
return linear_op.apply(
|
||||
input=A_ref,
|
||||
weight=B,
|
||||
weight_scale=Bs,
|
||||
input_scale=None,
|
||||
bias=None,
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ from bench_utils import (
|
||||
Color,
|
||||
logger,
|
||||
)
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
|
||||
# Conversation ID is a string (e.g: "UzTK34D")
|
||||
@ -417,6 +418,10 @@ def generate_conversations(
|
||||
data = file.read()
|
||||
tokens_in_file = tokenizer.encode(data, add_special_tokens=False)
|
||||
list_of_tokens.extend(tokens_in_file)
|
||||
logger.info(
|
||||
f"Loaded {len(tokens_in_file)} tokens from file {filename}, "
|
||||
f"total tokens so far: {len(list_of_tokens)}"
|
||||
)
|
||||
|
||||
conversations: ConversationsMap = {}
|
||||
conv_id = 0
|
||||
@ -449,18 +454,25 @@ def generate_conversations(
|
||||
)
|
||||
base_offset += common_prefix_tokens
|
||||
|
||||
for conv_id in range(args.num_conversations):
|
||||
for conv_id in tqdm(
|
||||
range(args.num_conversations),
|
||||
total=args.num_conversations,
|
||||
desc="Generating conversations",
|
||||
unit="conv",
|
||||
):
|
||||
# Generate a single conversation
|
||||
messages: MessagesList = []
|
||||
|
||||
nturns = turn_count[conv_id]
|
||||
|
||||
# User prompt token count per turn (with lower limit)
|
||||
input_token_count: np.ndarray = args.input_num_tokens.sample(nturns)
|
||||
input_token_count: np.ndarray = args.input_num_tokens.sample(nturns).astype(int)
|
||||
input_token_count = np.maximum(input_token_count, base_prompt_token_count)
|
||||
|
||||
# Assistant answer token count per turn (with lower limit)
|
||||
output_token_count: np.ndarray = args.output_num_tokens.sample(nturns)
|
||||
output_token_count: np.ndarray = args.output_num_tokens.sample(nturns).astype(
|
||||
int
|
||||
)
|
||||
output_token_count = np.maximum(output_token_count, 1)
|
||||
|
||||
user_turn = True
|
||||
|
||||
@ -55,6 +55,7 @@ class ClientArgs(NamedTuple):
|
||||
verify_output: bool
|
||||
conversation_sampling: ConversationSampling
|
||||
request_rate: float
|
||||
max_retries: int
|
||||
|
||||
|
||||
class RequestArgs(NamedTuple):
|
||||
@ -63,6 +64,7 @@ class RequestArgs(NamedTuple):
|
||||
stream: bool
|
||||
limit_min_tokens: int # Use negative value for no limit
|
||||
limit_max_tokens: int # Use negative value for no limit
|
||||
timeout_sec: int
|
||||
|
||||
|
||||
class BenchmarkArgs(NamedTuple):
|
||||
@ -214,6 +216,7 @@ async def send_request(
|
||||
stream: bool = True,
|
||||
min_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout_sec: int = 120,
|
||||
) -> ServerResponse:
|
||||
payload = {
|
||||
"model": model,
|
||||
@ -235,10 +238,16 @@ async def send_request(
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
# Calculate the timeout for the request
|
||||
timeout_sec = 120
|
||||
if max_tokens is not None:
|
||||
# Assume TPOT of 200ms and use max_tokens to determine timeout
|
||||
timeout_sec = max(timeout_sec, int(max_tokens * 0.2))
|
||||
token_based_timeout = int(max_tokens * 0.2)
|
||||
if token_based_timeout > timeout_sec:
|
||||
timeout_sec = token_based_timeout
|
||||
logger.info(
|
||||
"Using timeout of %ds based on max_tokens %d",
|
||||
timeout_sec,
|
||||
max_tokens,
|
||||
)
|
||||
timeout = aiohttp.ClientTimeout(total=timeout_sec)
|
||||
|
||||
valid_response = True
|
||||
@ -409,6 +418,7 @@ async def send_turn(
|
||||
req_args.stream,
|
||||
min_tokens,
|
||||
max_tokens,
|
||||
req_args.timeout_sec,
|
||||
)
|
||||
|
||||
if response.valid is False:
|
||||
@ -518,6 +528,25 @@ async def poisson_sleep(request_rate: float, verbose: bool = False) -> None:
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
|
||||
async def exponential_backoff_sleep(
|
||||
attempt_cnt: int,
|
||||
base_rate: float = 1.0,
|
||||
backoff_factor: float = 2.0,
|
||||
jitter_fraction: float = 0.10,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
# Sleep with exponential backoff and jitter after a failed request.
|
||||
backoff_delay = base_rate * (backoff_factor**attempt_cnt)
|
||||
jittered_delay = backoff_delay * (
|
||||
1 + np.random.uniform(-jitter_fraction, jitter_fraction)
|
||||
)
|
||||
|
||||
if verbose:
|
||||
logger.info(f"Backoff for {jittered_delay:.3f} seconds...")
|
||||
|
||||
await asyncio.sleep(jittered_delay)
|
||||
|
||||
|
||||
async def client_main(
|
||||
args: ClientArgs,
|
||||
req_args: RequestArgs,
|
||||
@ -646,49 +675,62 @@ async def client_main(
|
||||
)
|
||||
time_of_last_turn[conv_id] = curr_time_sec
|
||||
|
||||
success = True
|
||||
try:
|
||||
result = await send_turn(
|
||||
session,
|
||||
client_id,
|
||||
conv_id,
|
||||
messages,
|
||||
current_turn,
|
||||
tokenizer,
|
||||
req_args,
|
||||
args.print_content,
|
||||
args.verify_output,
|
||||
)
|
||||
if result is not None:
|
||||
result_queue.put(result)
|
||||
else:
|
||||
# None means that the request failed,
|
||||
# and should not be added to the statistics.
|
||||
success = False
|
||||
num_failures += 1
|
||||
|
||||
logger.warning(
|
||||
f"{Color.YELLOW}Client {client_id} - Request rejected during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501
|
||||
success = False
|
||||
for attempt_cnt in range(args.max_retries + 1):
|
||||
try:
|
||||
exception = False
|
||||
result = await send_turn(
|
||||
session,
|
||||
client_id,
|
||||
conv_id,
|
||||
messages,
|
||||
current_turn,
|
||||
tokenizer,
|
||||
req_args,
|
||||
args.print_content,
|
||||
args.verify_output,
|
||||
)
|
||||
if result is not None:
|
||||
result_queue.put(result)
|
||||
success = True
|
||||
break
|
||||
else:
|
||||
logger.warning(
|
||||
f"{Color.YELLOW}Client {client_id} - Request rejected during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501
|
||||
)
|
||||
except asyncio.exceptions.TimeoutError:
|
||||
exception = True
|
||||
logger.error(
|
||||
"%sClient %d - Timeout during conversation ID %s (turn: %d). "
|
||||
"Base timeout is %ss (set with --request-timeout-sec), but the "
|
||||
"effective timeout may be longer based on max_tokens. If this "
|
||||
"is unexpected, consider increasing the timeout or checking "
|
||||
"model performance.%s",
|
||||
Color.RED,
|
||||
client_id,
|
||||
conv_id,
|
||||
current_turn,
|
||||
req_args.timeout_sec,
|
||||
Color.RESET,
|
||||
)
|
||||
except Exception:
|
||||
exception = True
|
||||
logger.exception(
|
||||
f"{Color.RED}Client {client_id} - Exception during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501
|
||||
)
|
||||
|
||||
# Remove the conversation (should not be used again)
|
||||
active_convs.pop(conv_id)
|
||||
# Sleep before retry if not last attempt
|
||||
if not success and attempt_cnt < args.max_retries:
|
||||
await exponential_backoff_sleep(attempt_cnt, verbose=args.verbose)
|
||||
|
||||
except asyncio.exceptions.TimeoutError:
|
||||
if not success:
|
||||
num_failures += 1
|
||||
logger.exception(
|
||||
f"{Color.RED}Client {client_id} - Timeout during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501
|
||||
)
|
||||
break # Exit gracefully instead of raising an error
|
||||
# Remove the conversation (should not be used again)
|
||||
active_convs.pop(conv_id)
|
||||
if exception:
|
||||
break # Exit gracefully instead of raising an error
|
||||
|
||||
except Exception:
|
||||
num_failures += 1
|
||||
logger.exception(
|
||||
f"{Color.RED}Client {client_id} - Exception during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501
|
||||
)
|
||||
break # Exit gracefully instead of raising an error
|
||||
|
||||
if success:
|
||||
else:
|
||||
num_successes += 1
|
||||
|
||||
# Update the turns counter to include the LLM response
|
||||
@ -803,6 +845,7 @@ def get_client_config(
|
||||
verify_output=args.verify_output,
|
||||
conversation_sampling=args.conversation_sampling,
|
||||
request_rate=args.request_rate,
|
||||
max_retries=args.max_retries,
|
||||
)
|
||||
|
||||
if args.limit_min_tokens > 0 or args.limit_max_tokens > 0:
|
||||
@ -815,6 +858,9 @@ def get_client_config(
|
||||
"Invalid min/max tokens limits (min should not be larger than max)"
|
||||
)
|
||||
|
||||
if args.request_timeout_sec <= 0:
|
||||
raise ValueError("Request timeout must be a positive number")
|
||||
|
||||
# Arguments for API requests
|
||||
chat_url = f"{args.url}/v1/chat/completions"
|
||||
model_name = args.served_model_name if args.served_model_name else args.model
|
||||
@ -825,6 +871,7 @@ def get_client_config(
|
||||
stream=not args.no_stream,
|
||||
limit_min_tokens=args.limit_min_tokens,
|
||||
limit_max_tokens=args.limit_max_tokens,
|
||||
timeout_sec=args.request_timeout_sec,
|
||||
)
|
||||
|
||||
return client_args, req_args
|
||||
@ -968,7 +1015,7 @@ async def main_mp(
|
||||
f"(is alive: {client.is_alive()}){Color.RESET}"
|
||||
)
|
||||
|
||||
client.join(timeout=120)
|
||||
client.join(timeout=req_args.timeout_sec + 1)
|
||||
|
||||
if client.is_alive():
|
||||
logger.warning(
|
||||
@ -1334,6 +1381,16 @@ async def main() -> None:
|
||||
help="Expected request rate (Poisson process) per client in requests/sec."
|
||||
"Set to 0 for no delay between requests.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-retries",
|
||||
type=int,
|
||||
default=int(os.environ.get("MULTITURN_BENCH_MAX_RETRIES", "0")),
|
||||
help="Maximum number of retry attempts for timed-out requests. "
|
||||
"Default is 0 (no retries). "
|
||||
"Set to higher values to retry failed requests and maintain "
|
||||
"fair workload distribution. "
|
||||
"Can also be set via MULTITURN_BENCH_MAX_RETRIES environment variable.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--conversation-sampling",
|
||||
type=ConversationSampling,
|
||||
@ -1351,6 +1408,13 @@ async def main() -> None:
|
||||
action="store_true",
|
||||
help="Verify the LLM output (compare to the answers in the input JSON file)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--request-timeout-sec",
|
||||
type=int,
|
||||
default=120,
|
||||
help="Timeout in seconds for each API request (default: 120). "
|
||||
"Automatically increased if max tokens imply longer decoding.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-stream",
|
||||
|
||||
@ -2,4 +2,5 @@ numpy>=1.24
|
||||
pandas>=2.0.0
|
||||
aiohttp>=3.10
|
||||
transformers>=4.46
|
||||
xlsxwriter>=3.2.1
|
||||
xlsxwriter>=3.2.1
|
||||
tqdm>=4.66
|
||||
|
||||
@ -15,6 +15,7 @@ endif()
|
||||
#
|
||||
set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16})
|
||||
set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI})
|
||||
set(ENABLE_AMXBF16 $ENV{VLLM_CPU_AMXBF16})
|
||||
|
||||
include_directories("${CMAKE_SOURCE_DIR}/csrc")
|
||||
|
||||
@ -140,6 +141,22 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
set(ENABLE_AVX512VNNI OFF)
|
||||
message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.")
|
||||
endif()
|
||||
|
||||
find_isa(${CPUINFO} "amx_bf16" AMXBF16_FOUND)
|
||||
if (AMXBF16_FOUND OR ENABLE_AMXBF16)
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mamx-bf16" "-mamx-tile")
|
||||
set(ENABLE_AMXBF16 ON)
|
||||
add_compile_definitions(-DCPU_CAPABILITY_AMXBF16)
|
||||
else()
|
||||
set(ENABLE_AMXBF16 OFF)
|
||||
message(WARNING "Disable AMX_BF16 ISA support, requires gcc/g++ >= 12.3")
|
||||
endif()
|
||||
else()
|
||||
set(ENABLE_AMXBF16 OFF)
|
||||
message(WARNING "Disable AMX_BF16 ISA support, no amx_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AMXBF16=1.")
|
||||
endif()
|
||||
|
||||
elseif (AVX2_FOUND)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
|
||||
@ -193,7 +210,30 @@ endif()
|
||||
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
|
||||
# Fetch and build Arm Compute Library (ACL) as oneDNN's backend for AArch64
|
||||
# TODO [fadara01]: remove this once ACL can be fetched and built automatically as a dependency of oneDNN
|
||||
set(ONEDNN_AARCH64_USE_ACL OFF CACHE BOOL "")
|
||||
if(ASIMD_FOUND)
|
||||
# Set number of parallel build processes
|
||||
include(ProcessorCount)
|
||||
ProcessorCount(NPROC)
|
||||
if(NOT NPROC)
|
||||
set(NPROC 4)
|
||||
endif()
|
||||
# locate PyTorch's libgomp (e.g. site-packages/torch.libs/libgomp-947d5fa1.so.1.0.0)
|
||||
# and create a local shim dir with it
|
||||
vllm_prepare_torch_gomp_shim(VLLM_TORCH_GOMP_SHIM_DIR)
|
||||
|
||||
find_library(OPEN_MP
|
||||
NAMES gomp
|
||||
PATHS ${VLLM_TORCH_GOMP_SHIM_DIR}
|
||||
NO_DEFAULT_PATH
|
||||
REQUIRED
|
||||
)
|
||||
# Set LD_LIBRARY_PATH to include the shim dir at build time to use the same libgomp as PyTorch
|
||||
if (OPEN_MP)
|
||||
set(ENV{LD_LIBRARY_PATH} "${VLLM_TORCH_GOMP_SHIM_DIR}:$ENV{LD_LIBRARY_PATH}")
|
||||
endif()
|
||||
|
||||
# Fetch and populate ACL
|
||||
if(DEFINED ENV{ACL_ROOT_DIR} AND IS_DIRECTORY "$ENV{ACL_ROOT_DIR}")
|
||||
message(STATUS "Using ACL from specified source directory: $ENV{ACL_ROOT_DIR}")
|
||||
else()
|
||||
@ -207,38 +247,53 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
|
||||
GIT_PROGRESS TRUE
|
||||
)
|
||||
set(ENV{ACL_ROOT_DIR} "${arm_compute_SOURCE_DIR}")
|
||||
set(ACL_LIB_DIR "$ENV{ACL_ROOT_DIR}/build")
|
||||
endif()
|
||||
|
||||
# Build ACL with scons
|
||||
include(ProcessorCount)
|
||||
ProcessorCount(_NPROC)
|
||||
set(_scons_cmd
|
||||
scons -j${_NPROC}
|
||||
Werror=0 debug=0 neon=1 examples=0 embed_kernels=0 os=linux
|
||||
arch=armv8.2-a build=native benchmark_examples=0 fixed_format_kernels=1
|
||||
multi_isa=1 openmp=1 cppthreads=0
|
||||
# Build ACL with CMake
|
||||
set(ARM_COMPUTE_BUILD_SHARED_LIB "OFF")
|
||||
set(CMAKE_BUILD_TYPE "Release")
|
||||
set(ARM_COMPUTE_ARCH "armv8.2-a")
|
||||
set(ARM_COMPUTE_ENABLE_ASSERTS "OFF")
|
||||
set(ARM_COMPUTE_ENABLE_CPPTHREADS "OFF")
|
||||
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
|
||||
set(ARM_COMPUTE_ENABLE_OPENMP "ON")
|
||||
set(ARM_COMPUTE_ENABLE_WERROR "OFF")
|
||||
set(ARM_COMPUTE_BUILD_EXAMPLES "OFF")
|
||||
set(ARM_COMPUTE_BUILD_TESTING "OFF")
|
||||
|
||||
set(_cmake_config_cmd
|
||||
${CMAKE_COMMAND} -G Ninja -B build
|
||||
-DARM_COMPUTE_BUILD_SHARED_LIB=OFF
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DARM_COMPUTE_ARCH=armv8.2-a
|
||||
-DARM_COMPUTE_ENABLE_ASSERTS=OFF
|
||||
-DARM_COMPUTE_ENABLE_CPPTHREADS=OFF
|
||||
-DARM_COMPUTE_ENABLE_OPENMP=ON
|
||||
-DARM_COMPUTE_ENABLE_WERROR=OFF
|
||||
-DARM_COMPUTE_BUILD_EXAMPLES=OFF
|
||||
-DARM_COMPUTE_BUILD_TESTING=OFF)
|
||||
set(_cmake_build_cmd
|
||||
${CMAKE_COMMAND} --build build -- -j${NPROC}
|
||||
)
|
||||
|
||||
# locate PyTorch's libgomp (e.g. site-packages/torch.libs/libgomp-947d5fa1.so.1.0.0)
|
||||
# and create a local shim dir with it
|
||||
include("${CMAKE_CURRENT_LIST_DIR}/utils.cmake")
|
||||
vllm_prepare_torch_gomp_shim(VLLM_TORCH_GOMP_SHIM_DIR)
|
||||
|
||||
if(NOT VLLM_TORCH_GOMP_SHIM_DIR STREQUAL "")
|
||||
list(APPEND _scons_cmd extra_link_flags=-L${VLLM_TORCH_GOMP_SHIM_DIR})
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND ${_scons_cmd}
|
||||
COMMAND ${_cmake_config_cmd}
|
||||
WORKING_DIRECTORY "$ENV{ACL_ROOT_DIR}"
|
||||
)
|
||||
execute_process(
|
||||
COMMAND ${_cmake_build_cmd}
|
||||
WORKING_DIRECTORY "$ENV{ACL_ROOT_DIR}"
|
||||
RESULT_VARIABLE _acl_rc
|
||||
)
|
||||
|
||||
if(NOT _acl_rc EQUAL 0)
|
||||
message(FATAL_ERROR "ACL SCons build failed (exit ${_acl_rc}).")
|
||||
endif()
|
||||
message(STATUS "Arm Compute Library (ACL) built successfully.")
|
||||
|
||||
set(ONEDNN_AARCH64_USE_ACL "ON")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
|
||||
# VLLM/oneDNN settings for ACL
|
||||
set(ONEDNN_AARCH64_USE_ACL ON CACHE BOOL "" FORCE)
|
||||
add_compile_definitions(VLLM_USE_ACL)
|
||||
endif()
|
||||
|
||||
@ -275,7 +330,10 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
|
||||
set(ONEDNN_VERBOSE "OFF")
|
||||
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
|
||||
|
||||
set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE})
|
||||
set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size
|
||||
FetchContent_MakeAvailable(oneDNN)
|
||||
set(CMAKE_BUILD_TYPE ${VLLM_BUILD_TYPE})
|
||||
add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp")
|
||||
target_include_directories(
|
||||
dnnl_ext
|
||||
@ -305,14 +363,14 @@ endif()
|
||||
#
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/activation.cpp"
|
||||
"csrc/cpu/attention.cpp"
|
||||
"csrc/cpu/cache.cpp"
|
||||
"csrc/cpu/utils.cpp"
|
||||
"csrc/cpu/layernorm.cpp"
|
||||
"csrc/cpu/mla_decode.cpp"
|
||||
"csrc/cpu/pos_encoding.cpp"
|
||||
"csrc/cpu/torch_bindings.cpp"
|
||||
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
|
||||
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp"
|
||||
"csrc/cpu/cpu_attn.cpp"
|
||||
"csrc/cpu/scratchpad_manager.cpp"
|
||||
"csrc/cpu/torch_bindings.cpp")
|
||||
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
set(VLLM_EXT_SRC
|
||||
|
||||
@ -1,798 +0,0 @@
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
struct KernelVecType {
|
||||
using q_load_vec_type = void;
|
||||
using q_vec_type = void;
|
||||
using k_load_vec_type = void;
|
||||
using k_vec_type = void;
|
||||
using qk_acc_vec_type = void;
|
||||
using v_load_vec_type = void;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<float> {
|
||||
using q_load_vec_type = vec_op::FP32Vec4;
|
||||
using q_vec_type = vec_op::FP32Vec16;
|
||||
using k_load_vec_type = vec_op::FP32Vec16;
|
||||
using k_vec_type = vec_op::FP32Vec16;
|
||||
using qk_acc_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
#if defined(__powerpc64__) || defined(__s390x__)
|
||||
// Power and s390x architecture-specific vector types
|
||||
using q_load_vec_type = vec_op::FP32Vec8;
|
||||
using k_load_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::FP32Vec16;
|
||||
#else
|
||||
// Fallback for other architectures, including x86
|
||||
using q_load_vec_type = vec_op::FP16Vec8;
|
||||
using k_load_vec_type = vec_op::FP16Vec16;
|
||||
using v_load_vec_type = vec_op::FP16Vec16;
|
||||
#endif
|
||||
using q_vec_type = vec_op::FP32Vec16;
|
||||
using k_vec_type = vec_op::FP32Vec16;
|
||||
using qk_acc_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using q_load_vec_type = vec_op::BF16Vec8;
|
||||
using q_vec_type = vec_op::BF16Vec32;
|
||||
using k_load_vec_type = vec_op::BF16Vec32;
|
||||
using k_vec_type = vec_op::BF16Vec32;
|
||||
using qk_acc_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::BF16Vec16;
|
||||
};
|
||||
#else
|
||||
#ifdef __aarch64__
|
||||
#ifndef ARM_BF16_SUPPORT
|
||||
// pass
|
||||
#else
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using q_load_vec_type = vec_op::BF16Vec8;
|
||||
using q_vec_type = vec_op::FP32Vec16;
|
||||
using k_load_vec_type = vec_op::BF16Vec16;
|
||||
using k_vec_type = vec_op::FP32Vec16;
|
||||
using qk_acc_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::BF16Vec16;
|
||||
};
|
||||
#endif
|
||||
#else
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using q_load_vec_type = vec_op::BF16Vec8;
|
||||
using q_vec_type = vec_op::FP32Vec16;
|
||||
using k_load_vec_type = vec_op::BF16Vec16;
|
||||
using k_vec_type = vec_op::FP32Vec16;
|
||||
using qk_acc_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::BF16Vec16;
|
||||
};
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE std::pair<T, T> reduceSoftmax(T* data, const int size,
|
||||
const int capacity) {
|
||||
T max = data[0];
|
||||
for (int i = 1; i < size; ++i) {
|
||||
max = max >= data[i] ? max : data[i];
|
||||
}
|
||||
|
||||
T sum = 0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
data[i] = std::exp(data[i] - max);
|
||||
sum += data[i];
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
for (; i < size; ++i) {
|
||||
data[i] /= sum;
|
||||
}
|
||||
|
||||
for (; i < capacity; ++i) {
|
||||
data[i] = 0;
|
||||
}
|
||||
|
||||
return {max, sum};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
|
||||
const int capacity,
|
||||
const float alibi_slope,
|
||||
const int start_index,
|
||||
const int seq_len) {
|
||||
data[0] += alibi_slope * (start_index - seq_len + 1);
|
||||
T max = data[0];
|
||||
for (int i = 1; i < size; ++i) {
|
||||
T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1);
|
||||
data[i] = qk;
|
||||
max = max >= qk ? max : qk;
|
||||
}
|
||||
|
||||
T sum = 0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
data[i] = std::exp(data[i] - max);
|
||||
sum += data[i];
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
for (; i < size; ++i) {
|
||||
data[i] /= sum;
|
||||
}
|
||||
|
||||
for (; i < capacity; ++i) {
|
||||
data[i] = 0;
|
||||
}
|
||||
|
||||
return {max, sum};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE void reducePartitionSoftmax(const T* max_data, T* sum_data,
|
||||
const int size) {
|
||||
T max = max_data[0];
|
||||
for (int i = 1; i < size; ++i) {
|
||||
max = max >= max_data[i] ? max : max_data[i];
|
||||
}
|
||||
|
||||
T rescaled_sum = 0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
T rescale_factor = std::exp(max_data[i] - max);
|
||||
rescaled_sum += rescale_factor * sum_data[i];
|
||||
sum_data[i] *= rescale_factor;
|
||||
}
|
||||
for (int i = 0; i < size; ++i) {
|
||||
sum_data[i] /= rescaled_sum + 1e-8;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int x>
|
||||
struct reduceQKBlockKernel {
|
||||
using q_load_vec_type = typename KernelVecType<scalar_t>::q_load_vec_type;
|
||||
using q_vec_type = typename KernelVecType<scalar_t>::q_vec_type;
|
||||
using k_load_vec_type = typename KernelVecType<scalar_t>::k_load_vec_type;
|
||||
using k_vec_type = typename KernelVecType<scalar_t>::k_vec_type;
|
||||
using qk_acc_vec_type = typename KernelVecType<scalar_t>::qk_acc_vec_type;
|
||||
|
||||
constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x;
|
||||
constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP;
|
||||
constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4;
|
||||
|
||||
static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4);
|
||||
static_assert(k_load_vec_type::get_elem_num() % x == 0);
|
||||
static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);
|
||||
|
||||
FORCE_INLINE static void call(const scalar_t* __restrict__ q,
|
||||
const scalar_t* __restrict__ k_block,
|
||||
float* __restrict__ logits, float scale,
|
||||
const int token_num) {
|
||||
const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;
|
||||
|
||||
qk_acc_vec_type group_accums[MAX_GROUP_NUM];
|
||||
if (token_num == BLOCK_SIZE) {
|
||||
for (int q_offset = 0; q_offset < HEAD_SIZE;
|
||||
q_offset += x, k_block += x * BLOCK_SIZE) {
|
||||
q_load_vec_type q_load_group_vec(q + q_offset);
|
||||
q_vec_type q_group_vec(q_load_group_vec);
|
||||
|
||||
vec_op::unroll_loop<int, MAX_GROUP_NUM>(
|
||||
[k_block, &q_group_vec, &group_accums](int token_group_idx) {
|
||||
k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
|
||||
TOKEN_PER_GROUP);
|
||||
k_vec_type k_group_vec(k_load_group_vec);
|
||||
vec_op::fma(group_accums[token_group_idx], q_group_vec,
|
||||
k_group_vec);
|
||||
vec_op::prefetch(k_block + x * BLOCK_SIZE +
|
||||
token_group_idx * x * TOKEN_PER_GROUP);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
for (int q_offset = 0; q_offset < HEAD_SIZE;
|
||||
q_offset += x, k_block += x * BLOCK_SIZE) {
|
||||
q_load_vec_type q_load_group_vec(q + q_offset);
|
||||
q_vec_type q_group_vec(q_load_group_vec);
|
||||
for (int token_group_start = 0; token_group_start < group_num;
|
||||
token_group_start += UNROLL_GROUP_NUM) {
|
||||
vec_op::unroll_loop<int, UNROLL_GROUP_NUM>(
|
||||
[token_group_start, k_block, &q_group_vec,
|
||||
&group_accums](int token_group_idx) {
|
||||
token_group_idx += token_group_start;
|
||||
k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
|
||||
TOKEN_PER_GROUP);
|
||||
k_vec_type k_group_vec(k_load_group_vec);
|
||||
vec_op::fma(group_accums[token_group_idx], q_group_vec,
|
||||
k_group_vec);
|
||||
vec_op::prefetch(k_block + x * BLOCK_SIZE +
|
||||
token_group_idx * x * TOKEN_PER_GROUP);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int token_group_idx = 0; token_group_idx < group_num;
|
||||
++token_group_idx) {
|
||||
vec_op::unroll_loop<int, TOKEN_PER_GROUP>(
|
||||
[&group_accums, logits, scale, token_group_idx](int token_idx) {
|
||||
float dot_v =
|
||||
group_accums[token_group_idx]
|
||||
.template reduce_sub_sum<qk_acc_vec_type::get_elem_num() /
|
||||
TOKEN_PER_GROUP>(token_idx);
|
||||
logits[token_group_idx * TOKEN_PER_GROUP + token_idx] =
|
||||
dot_v * scale;
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
|
||||
int HEAD_PARTITION_SIZE, typename acc_t>
|
||||
FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block,
|
||||
acc_t&& acc) {
|
||||
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
|
||||
constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
|
||||
static_assert(BLOCK_SIZE == ELEM_NUM);
|
||||
vec_op::FP32Vec16 prob_vec(prob);
|
||||
|
||||
vec_op::unroll_loop<int, HEAD_PARTITION_SIZE>([&](int head_elem_idx) {
|
||||
v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx);
|
||||
vec_op::FP32Vec16 fp32_v_vec(v_vec);
|
||||
acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
|
||||
});
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
// Paged attention v1
|
||||
namespace {
|
||||
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
|
||||
struct paged_attention_v1_impl {
|
||||
static void call(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs,
|
||||
// max_num_blocks_per_seq]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const int num_seqs, const int num_heads) {
|
||||
constexpr int x = 16 / sizeof(scalar_t);
|
||||
const int num_queries_per_kv = num_heads / num_kv_heads;
|
||||
|
||||
static_assert(BLOCK_SIZE == 16);
|
||||
|
||||
int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE;
|
||||
int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0;
|
||||
TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0);
|
||||
|
||||
const int parallel_work_item_num = omp_get_max_threads();
|
||||
|
||||
size_t logits_bytes =
|
||||
parallel_work_item_num * max_seq_len_padded * sizeof(float);
|
||||
float* logits = (float*)std::aligned_alloc(
|
||||
64, logits_bytes); // Cacheline alignment for each context token.
|
||||
// [parallel_work_item_num, max_seq_len_padded]
|
||||
|
||||
#pragma omp parallel for collapse(2) schedule(dynamic, 1)
|
||||
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||
int seq_len = seq_lens[seq_idx];
|
||||
const int* seq_block_table =
|
||||
block_tables + max_num_blocks_per_seq * seq_idx;
|
||||
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
|
||||
const scalar_t* __restrict__ q_vec_ptr =
|
||||
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
|
||||
float* __restrict__ thread_block_logits =
|
||||
logits + omp_get_thread_num() * max_seq_len_padded;
|
||||
|
||||
// Compute logits
|
||||
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||
const scalar_t* __restrict__ k_block_cache_ptr =
|
||||
k_cache + physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride;
|
||||
float* __restrict__ head_block_logits =
|
||||
thread_block_logits + block_idx * BLOCK_SIZE;
|
||||
|
||||
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
|
||||
q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
|
||||
block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
|
||||
}
|
||||
|
||||
// Compute softmax
|
||||
if (alibi_slopes) {
|
||||
reduceSoftmaxAlibi(thread_block_logits, seq_len,
|
||||
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
|
||||
seq_len);
|
||||
} else {
|
||||
reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
|
||||
}
|
||||
|
||||
// Compute value
|
||||
constexpr int head_elem_num_per_partition = 16;
|
||||
constexpr int head_partition_num =
|
||||
HEAD_SIZE / head_elem_num_per_partition;
|
||||
for (int head_part_idx = 0; head_part_idx < head_partition_num;
|
||||
++head_part_idx) {
|
||||
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
|
||||
scalar_t* __restrict__ out_ptr =
|
||||
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
|
||||
head_part_idx * head_elem_num_per_partition;
|
||||
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||
const float* __restrict__ prob_vec_ptr =
|
||||
thread_block_logits + block_idx * BLOCK_SIZE;
|
||||
const scalar_t* __restrict__ v_block_cache_ptr =
|
||||
v_cache + physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride +
|
||||
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||
reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
|
||||
head_elem_num_per_partition>(
|
||||
prob_vec_ptr, v_block_cache_ptr, accums);
|
||||
|
||||
if (block_idx != block_num - 1) {
|
||||
const int64_t next_physical_block_idx =
|
||||
seq_block_table[block_idx + 1];
|
||||
const scalar_t* __restrict__ next_v_block_cache_ptr =
|
||||
v_cache + next_physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride +
|
||||
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||
vec_op::unroll_loop<int, head_elem_num_per_partition>(
|
||||
[&](int head_elem_idx) {
|
||||
if (head_elem_idx % 2 == 0) {
|
||||
vec_op::prefetch(next_v_block_cache_ptr +
|
||||
BLOCK_SIZE * head_elem_idx);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
vec_op::unroll_loop<int, head_elem_num_per_partition>(
|
||||
[&](int head_elem_idx) {
|
||||
float value = accums[head_elem_idx].reduce_sum();
|
||||
vec_op::storeFP32(value, out_ptr + head_elem_idx);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
std::free(logits);
|
||||
}
|
||||
};
|
||||
|
||||
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
|
||||
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
|
||||
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
|
||||
block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
|
||||
num_heads);
|
||||
|
||||
template <typename T, int BLOCK_SIZE>
|
||||
void paged_attention_v1_impl_launcher(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
int max_num_blocks_per_seq = block_tables.size(1);
|
||||
int q_stride = query.stride(0);
|
||||
int kv_block_stride = key_cache.stride(0);
|
||||
int kv_head_stride = key_cache.stride(1);
|
||||
|
||||
// NOTE: alibi_slopes is optional.
|
||||
const float* alibi_slopes_ptr =
|
||||
alibi_slopes
|
||||
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
|
||||
: nullptr;
|
||||
|
||||
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
|
||||
switch (head_size) {
|
||||
case 32:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
|
||||
break;
|
||||
case 80:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
|
||||
break;
|
||||
case 96:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
|
||||
break;
|
||||
case 112:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
|
||||
break;
|
||||
case 192:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
|
||||
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
|
||||
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
|
||||
seq_lens, max_seq_len, alibi_slopes);
|
||||
|
||||
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
|
||||
switch (block_size) { \
|
||||
case 16: \
|
||||
CALL_V1_KERNEL_LAUNCHER(T, 16); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void paged_attention_v1(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
||||
"CPU backend does not support blocksparse attention yet.");
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
|
||||
[&] {
|
||||
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
|
||||
CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
|
||||
CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl)
|
||||
});
|
||||
}
|
||||
|
||||
// Paged attention v2
|
||||
namespace {
|
||||
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
|
||||
struct paged_attention_v2_impl {
|
||||
static void call(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
float* __restrict__ exp_sums, // [num_seqs, num_heads,
|
||||
// max_num_partitions]
|
||||
float* __restrict__ max_logits, // [num_seqs, num_heads,
|
||||
// max_num_partitions]
|
||||
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
|
||||
// max_num_partitions, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs,
|
||||
// max_num_blocks_per_seq]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const int num_seqs, const int num_heads, const int max_num_partitions) {
|
||||
constexpr int x = 16 / sizeof(scalar_t);
|
||||
const int num_queries_per_kv = num_heads / num_kv_heads;
|
||||
|
||||
static_assert(BLOCK_SIZE == 16);
|
||||
static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0);
|
||||
static_assert(PARTITION_SIZE % BLOCK_SIZE == 0);
|
||||
|
||||
#pragma omp parallel for collapse(3) schedule(static, 1)
|
||||
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||
for (int partition_idx = 0; partition_idx < max_num_partitions;
|
||||
++partition_idx) {
|
||||
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
const int start_token_idx = partition_idx * PARTITION_SIZE;
|
||||
|
||||
if (start_token_idx >= seq_len) continue;
|
||||
|
||||
const int partition_num =
|
||||
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
|
||||
const bool no_reduce = (partition_num == 1);
|
||||
const int token_num =
|
||||
(std::min(seq_len, start_token_idx + PARTITION_SIZE) -
|
||||
start_token_idx);
|
||||
const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const int last_block_token_num =
|
||||
token_num - (block_num - 1) * BLOCK_SIZE;
|
||||
const int* seq_block_table = block_tables +
|
||||
max_num_blocks_per_seq * seq_idx +
|
||||
start_token_idx / BLOCK_SIZE;
|
||||
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
|
||||
const scalar_t* __restrict__ q_vec_ptr =
|
||||
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
|
||||
float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};
|
||||
|
||||
// Compute logits
|
||||
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||
const scalar_t* __restrict__ k_block_cache_ptr =
|
||||
k_cache + physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride;
|
||||
float* __restrict__ head_block_logits =
|
||||
logits + block_idx * BLOCK_SIZE;
|
||||
|
||||
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
|
||||
q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
|
||||
block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
|
||||
}
|
||||
|
||||
std::pair<float, float> max_and_sum;
|
||||
if (alibi_slopes) {
|
||||
max_and_sum = reduceSoftmaxAlibi(
|
||||
logits, token_num, block_num * BLOCK_SIZE,
|
||||
alibi_slopes[head_idx], start_token_idx, seq_len);
|
||||
} else {
|
||||
max_and_sum =
|
||||
reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
|
||||
}
|
||||
|
||||
auto&& [max_logit, exp_sum] = max_and_sum;
|
||||
|
||||
scalar_t* __restrict__ output_buffer = nullptr;
|
||||
if (!no_reduce) {
|
||||
auto idx = seq_idx * num_heads * max_num_partitions +
|
||||
head_idx * max_num_partitions + partition_idx;
|
||||
max_logits[idx] = max_logit;
|
||||
exp_sums[idx] = exp_sum;
|
||||
output_buffer =
|
||||
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
|
||||
head_idx * max_num_partitions * HEAD_SIZE +
|
||||
partition_idx * HEAD_SIZE;
|
||||
} else {
|
||||
output_buffer =
|
||||
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||
}
|
||||
|
||||
// Compute value
|
||||
constexpr int head_elem_num_per_partition = 16;
|
||||
constexpr int head_partition_num =
|
||||
HEAD_SIZE / head_elem_num_per_partition;
|
||||
for (int head_part_idx = 0; head_part_idx < head_partition_num;
|
||||
++head_part_idx) {
|
||||
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
|
||||
scalar_t* __restrict__ out_ptr =
|
||||
output_buffer + head_part_idx * head_elem_num_per_partition;
|
||||
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||
const float* __restrict__ prob_vec_ptr =
|
||||
logits + block_idx * BLOCK_SIZE;
|
||||
const scalar_t* __restrict__ v_block_cache_ptr =
|
||||
v_cache + physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride +
|
||||
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||
reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
|
||||
head_elem_num_per_partition>(
|
||||
prob_vec_ptr, v_block_cache_ptr, accums);
|
||||
|
||||
if (block_idx != block_num - 1) {
|
||||
const int64_t next_physical_block_idx =
|
||||
seq_block_table[block_idx + 1];
|
||||
const scalar_t* __restrict__ next_v_block_cache_ptr =
|
||||
v_cache + next_physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride +
|
||||
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||
vec_op::unroll_loop<int, head_elem_num_per_partition>(
|
||||
[&](int head_elem_idx) {
|
||||
if (head_elem_idx % 2 == 0) {
|
||||
vec_op::prefetch(next_v_block_cache_ptr +
|
||||
BLOCK_SIZE * head_elem_idx);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
vec_op::unroll_loop<int, head_elem_num_per_partition>(
|
||||
[&](int head_elem_idx) {
|
||||
float value = accums[head_elem_idx].reduce_sum();
|
||||
vec_op::storeFP32(value, out_ptr + head_elem_idx);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rescale partition softmax and store the factors to exp_sums
|
||||
#pragma omp parallel for collapse(2) schedule(static, 1)
|
||||
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
const int partition_num =
|
||||
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
|
||||
|
||||
if (partition_num == 1) continue;
|
||||
|
||||
reducePartitionSoftmax(
|
||||
max_logits + seq_idx * num_heads * max_num_partitions +
|
||||
head_idx * max_num_partitions,
|
||||
exp_sums + seq_idx * num_heads * max_num_partitions +
|
||||
head_idx * max_num_partitions,
|
||||
partition_num);
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce values
|
||||
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
|
||||
static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
|
||||
constexpr int head_elem_num_per_group =
|
||||
16; // Note: didn't align with the cacheline size, due to some
|
||||
// HEAD_SIZE didn't align with 64 bytes
|
||||
static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
|
||||
constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
|
||||
const float* __restrict__ rescale_factors = exp_sums;
|
||||
#pragma omp parallel for collapse(3) schedule(static, 1)
|
||||
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||
for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
const int partition_num =
|
||||
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
|
||||
|
||||
if (partition_num == 1) continue;
|
||||
|
||||
const float* __restrict__ seq_head_rescale_factors =
|
||||
rescale_factors + seq_idx * num_heads * max_num_partitions +
|
||||
head_idx * max_num_partitions;
|
||||
const scalar_t* __restrict__ seq_head_tmp_out =
|
||||
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
|
||||
head_idx * max_num_partitions * HEAD_SIZE +
|
||||
group_idx * head_elem_num_per_group;
|
||||
scalar_t* __restrict__ seq_head_output =
|
||||
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
|
||||
group_idx * head_elem_num_per_group;
|
||||
|
||||
vec_op::FP32Vec16 acc;
|
||||
for (int i = 0; i < partition_num; ++i) {
|
||||
vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]);
|
||||
v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE);
|
||||
vec_op::FP32Vec16 fp32_value(value);
|
||||
acc = acc + fp32_value * rescale_factor;
|
||||
}
|
||||
v_load_vec_type cast_acc(acc);
|
||||
cast_acc.save(seq_head_output);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
|
||||
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
|
||||
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
|
||||
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
|
||||
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
|
||||
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
|
||||
max_num_partitions);
|
||||
|
||||
template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
|
||||
void paged_attention_v2_impl_launcher(
|
||||
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
|
||||
int max_seq_len, const std::optional<torch::Tensor>& alibi_slopes) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
int max_num_blocks_per_seq = block_tables.size(1);
|
||||
int q_stride = query.stride(0);
|
||||
int kv_block_stride = key_cache.stride(0);
|
||||
int kv_head_stride = key_cache.stride(1);
|
||||
int max_num_partitions = exp_sums.size(-1);
|
||||
|
||||
// NOTE: alibi_slopes is optional.
|
||||
const float* alibi_slopes_ptr =
|
||||
alibi_slopes
|
||||
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
|
||||
: nullptr;
|
||||
|
||||
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
|
||||
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
|
||||
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
|
||||
switch (head_size) {
|
||||
case 32:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
|
||||
break;
|
||||
case 80:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
|
||||
break;
|
||||
case 96:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
|
||||
break;
|
||||
case 112:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
|
||||
break;
|
||||
case 192:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
|
||||
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
|
||||
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
||||
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
|
||||
alibi_slopes);
|
||||
|
||||
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
|
||||
switch (block_size) { \
|
||||
case 16: \
|
||||
CALL_V2_KERNEL_LAUNCHER(T, 16); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void paged_attention_v2(
|
||||
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
||||
"CPU backend does not support blocksparse attention yet.");
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
|
||||
[&] {
|
||||
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
|
||||
CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
|
||||
CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl)
|
||||
});
|
||||
}
|
||||
@ -1,214 +0,0 @@
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
#if defined(__x86_64__)
|
||||
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
|
||||
#else
|
||||
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
|
||||
std::vector<torch::Tensor> const& value_caches,
|
||||
const torch::Tensor& mapping_pairs,
|
||||
const int element_num_per_block,
|
||||
const int layer_num) {
|
||||
const size_t pair_num = mapping_pairs.size(0);
|
||||
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int layer = 0; layer < layer_num; ++layer) {
|
||||
for (size_t pair = 0; pair < pair_num; ++pair) {
|
||||
int64_t source_offset =
|
||||
element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
|
||||
int64_t target_offset =
|
||||
element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
|
||||
scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
|
||||
scalar_t* source_ptr = key_cache_ptr + source_offset;
|
||||
scalar_t* target_ptr = key_cache_ptr + target_offset;
|
||||
std::memcpy(target_ptr, source_ptr, block_bytes);
|
||||
|
||||
scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
|
||||
source_ptr = value_cache_ptr + source_offset;
|
||||
target_ptr = value_cache_ptr + target_offset;
|
||||
std::memcpy(target_ptr, source_ptr, block_bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void reshape_and_cache_cpu_impl(
|
||||
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
|
||||
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
|
||||
const int64_t* __restrict__ slot_mapping, const int num_tokens,
|
||||
const int key_stride, const int value_stride, const int num_heads,
|
||||
const int head_size, const int block_size, const int x) {
|
||||
const int block_elem_num = num_heads * head_size * block_size;
|
||||
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
if (slot_idx >= 0) {
|
||||
int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
|
||||
int src_value_head_idx =
|
||||
token_idx * value_stride + head_idx * head_size;
|
||||
const scalar_t* src_key_head_ptr = key + src_key_head_idx;
|
||||
const scalar_t* src_value_head_ptr = value + src_value_head_idx;
|
||||
const int64_t block_index = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
scalar_t* target_key_head_ptr = key_cache +
|
||||
block_elem_num * block_index +
|
||||
head_idx * block_size * head_size;
|
||||
scalar_t* target_value_head_ptr = value_cache +
|
||||
block_elem_num * block_index +
|
||||
head_idx * block_size * head_size;
|
||||
|
||||
for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
|
||||
const int64_t target_offset =
|
||||
src_key_idx * block_size + block_offset * x;
|
||||
for (int i = 0; i < x; ++i) {
|
||||
target_key_head_ptr[target_offset + i] =
|
||||
src_key_head_ptr[src_key_idx + i];
|
||||
}
|
||||
}
|
||||
|
||||
for (int src_value_idx = 0; src_value_idx < head_size;
|
||||
++src_value_idx) {
|
||||
const int64_t target_offset =
|
||||
src_value_idx * block_size + block_offset;
|
||||
target_value_head_ptr[target_offset] =
|
||||
src_value_head_ptr[src_value_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
void concat_and_cache_mla_cpu_impl(
|
||||
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
|
||||
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
|
||||
scalar_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
|
||||
// + pe_dim)]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int num_tokens, //
|
||||
const int block_stride, //
|
||||
const int entry_stride, //
|
||||
const int kv_c_stride, //
|
||||
const int k_pe_stride, //
|
||||
const int kv_lora_rank, //
|
||||
const int pe_dim, //
|
||||
const int block_size //
|
||||
) {
|
||||
#pragma omp parallel for
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
if (slot_idx < 0) {
|
||||
continue;
|
||||
}
|
||||
const int64_t block_idx = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
|
||||
auto copy = [&](const scalar_t* __restrict__ src,
|
||||
scalar_t* __restrict__ dst, int src_stride, int dst_stride,
|
||||
int size, int offset) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
const int64_t src_idx = token_idx * src_stride + i;
|
||||
const int64_t dst_idx =
|
||||
block_idx * block_stride + block_offset * entry_stride + i + offset;
|
||||
dst[dst_idx] = src[src_idx];
|
||||
}
|
||||
};
|
||||
|
||||
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
|
||||
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
|
||||
}
|
||||
}
|
||||
|
||||
// Note: the key_caches and value_caches vectors are constant but
|
||||
// not the Tensors they contain. The vectors need to be const refs
|
||||
// in order to satisfy pytorch's C++ operator registration code.
|
||||
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||
std::vector<torch::Tensor> const& value_caches,
|
||||
const torch::Tensor& block_mapping) {
|
||||
unsigned num_layers = key_caches.size();
|
||||
TORCH_CHECK(num_layers == value_caches.size());
|
||||
if (num_layers == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int element_num_per_block = key_caches[0][0].numel();
|
||||
DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
|
||||
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
|
||||
element_num_per_block, num_layers);
|
||||
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype,
|
||||
torch::Tensor& k_scale, torch::Tensor& v_scale) {
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
int block_size = key_cache.size(3);
|
||||
int x = key_cache.size(4);
|
||||
|
||||
int key_stride = key.stride(0);
|
||||
int value_stride = value.stride(0);
|
||||
|
||||
DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
|
||||
reshape_and_cache_cpu_impl<scalar_t>(
|
||||
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride,
|
||||
num_heads, head_size, block_size, x);
|
||||
CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void concat_and_cache_mla(
|
||||
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
|
||||
torch::Tensor& k_pe, // [num_tokens, pe_dim]
|
||||
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
|
||||
// pe_dim)]
|
||||
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
|
||||
const std::string& kv_cache_dtype, torch::Tensor& scale) {
|
||||
int num_tokens = slot_mapping.size(0);
|
||||
int kv_lora_rank = kv_c.size(1);
|
||||
int pe_dim = k_pe.size(1);
|
||||
int block_size = kv_cache.size(1);
|
||||
|
||||
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||
TORCH_CHECK(kv_cache_dtype != "fp8");
|
||||
|
||||
int kv_c_stride = kv_c.stride(0);
|
||||
int k_pe_stride = k_pe.stride(0);
|
||||
int block_stride = kv_cache.stride(0);
|
||||
int entry_stride = kv_cache.stride(1);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl)
|
||||
concat_and_cache_mla_cpu_impl<scalar_t>(
|
||||
kv_c.data_ptr<scalar_t>(), k_pe.data_ptr<scalar_t>(),
|
||||
kv_cache.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(),
|
||||
num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride,
|
||||
kv_lora_rank, pe_dim, block_size);
|
||||
CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
const torch::Tensor& block_mapping) {
|
||||
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
|
||||
}
|
||||
249
csrc/cpu/cpu_attn.cpp
Normal file
249
csrc/cpu/cpu_attn.cpp
Normal file
@ -0,0 +1,249 @@
|
||||
#include "cpu_attn_vec.hpp"
|
||||
#include "cpu_attn_vec16.hpp"
|
||||
|
||||
#ifdef CPU_CAPABILITY_AMXBF16
|
||||
#include "cpu_attn_amx.hpp"
|
||||
#define AMX_DISPATCH(...) \
|
||||
case cpu_attention::ISA::AMX: { \
|
||||
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::AMX, \
|
||||
scalar_t, head_dim>; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
#else
|
||||
#define AMX_DISPATCH(...) case cpu_attention::ISA::AMX:
|
||||
#endif
|
||||
|
||||
#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \
|
||||
case HEAD_DIM: { \
|
||||
constexpr size_t head_dim = HEAD_DIM; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
|
||||
#define CPU_ATTN_DISPATCH_CASE_HEADDIM(HEAD_DIM, ...) \
|
||||
[&] { \
|
||||
switch (HEAD_DIM) { \
|
||||
CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \
|
||||
CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \
|
||||
CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \
|
||||
CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \
|
||||
CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \
|
||||
CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \
|
||||
CPU_ATTN_DISPATCH_CASE(224, __VA_ARGS__) \
|
||||
CPU_ATTN_DISPATCH_CASE(256, __VA_ARGS__) \
|
||||
default: { \
|
||||
TORCH_CHECK(false, "Invalid CPU attention head_dim: " + \
|
||||
std::to_string(HEAD_DIM)); \
|
||||
} \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define CPU_ATTN_DISPATCH_IMPL(ISA_TYPE, ...) \
|
||||
[&] { \
|
||||
switch (ISA_TYPE) { \
|
||||
AMX_DISPATCH(__VA_ARGS__) \
|
||||
case cpu_attention::ISA::VEC: { \
|
||||
using attn_impl = \
|
||||
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC, scalar_t, \
|
||||
head_dim>; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case cpu_attention::ISA::VEC16: { \
|
||||
using attn_impl = \
|
||||
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC16, scalar_t, \
|
||||
head_dim>; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
default: { \
|
||||
TORCH_CHECK(false, "Invalid CPU attention ISA type."); \
|
||||
} \
|
||||
} \
|
||||
}()
|
||||
|
||||
torch::Tensor get_scheduler_metadata(
|
||||
const int64_t num_req, const int64_t num_heads_q,
|
||||
const int64_t num_heads_kv, const int64_t head_dim,
|
||||
const torch::Tensor& seq_lens, at::ScalarType dtype,
|
||||
const torch::Tensor& query_start_loc, const bool casual,
|
||||
const int64_t window_size, const std::string& isa_hint,
|
||||
const bool enable_kv_split) {
|
||||
cpu_attention::ISA isa;
|
||||
if (isa_hint == "amx") {
|
||||
isa = cpu_attention::ISA::AMX;
|
||||
} else if (isa_hint == "vec") {
|
||||
isa = cpu_attention::ISA::VEC;
|
||||
} else if (isa_hint == "vec16") {
|
||||
isa = cpu_attention::ISA::VEC16;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint);
|
||||
}
|
||||
|
||||
cpu_attention::AttentionScheduler::ScheduleInput input;
|
||||
input.num_reqs = num_req;
|
||||
input.num_heads_q = num_heads_q;
|
||||
input.num_heads_kv = num_heads_kv;
|
||||
input.head_dim = head_dim;
|
||||
input.query_start_loc = query_start_loc.data_ptr<int32_t>();
|
||||
input.seq_lens = seq_lens.data_ptr<int32_t>();
|
||||
if (window_size != -1) {
|
||||
input.left_sliding_window_size = window_size - 1;
|
||||
if (casual) {
|
||||
input.right_sliding_window_size = 0;
|
||||
} else {
|
||||
input.right_sliding_window_size = window_size - 1;
|
||||
}
|
||||
} else {
|
||||
input.left_sliding_window_size = -1;
|
||||
if (casual) {
|
||||
input.right_sliding_window_size = 0;
|
||||
} else {
|
||||
input.right_sliding_window_size = -1;
|
||||
}
|
||||
}
|
||||
input.casual = casual;
|
||||
input.isa = isa;
|
||||
input.enable_kv_split = enable_kv_split;
|
||||
TORCH_CHECK(casual, "Only supports casual mask for now.");
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
|
||||
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
|
||||
CPU_ATTN_DISPATCH_IMPL(isa, [&]() {
|
||||
input.elem_size = sizeof(scalar_t);
|
||||
input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t);
|
||||
input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t);
|
||||
input.output_buffer_elem_size =
|
||||
sizeof(attn_impl::partial_output_buffer_t);
|
||||
input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration;
|
||||
input.kv_block_alignment = attn_impl::BlockSizeAlignment;
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
cpu_attention::AttentionScheduler scheduler;
|
||||
torch::Tensor metadata = scheduler.schedule(input);
|
||||
return metadata;
|
||||
}
|
||||
|
||||
void cpu_attn_reshape_and_cache(
|
||||
const torch::Tensor& key, // [token_num, head_num, head_size]
|
||||
const torch::Tensor& value, // [token_num, head_num, head_size]
|
||||
torch::Tensor&
|
||||
key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
|
||||
const torch::Tensor& slot_mapping, const std::string& isa) {
|
||||
TORCH_CHECK_EQ(key.dim(), 3);
|
||||
TORCH_CHECK_EQ(value.dim(), 3);
|
||||
TORCH_CHECK_EQ(key_cache.dim(), 4);
|
||||
TORCH_CHECK_EQ(value_cache.dim(), 4);
|
||||
TORCH_CHECK_EQ(key.stride(2), 1);
|
||||
TORCH_CHECK_EQ(value.stride(2), 1);
|
||||
|
||||
const int64_t token_num = key.size(0);
|
||||
const int64_t key_token_num_stride = key.stride(0);
|
||||
const int64_t value_token_num_stride = value.stride(0);
|
||||
const int64_t head_num = value.size(1);
|
||||
const int64_t key_head_num_stride = key.stride(1);
|
||||
const int64_t value_head_num_stride = value.stride(1);
|
||||
const int64_t num_blocks = key_cache.size(0);
|
||||
const int64_t num_blocks_stride = key_cache.stride(0);
|
||||
const int64_t cache_head_num_stride = key_cache.stride(1);
|
||||
const int64_t block_size = key_cache.size(2);
|
||||
const int64_t block_size_stride = key_cache.stride(2);
|
||||
const int64_t head_dim = key.size(-1);
|
||||
|
||||
cpu_attention::ISA isa_tag = [&]() {
|
||||
if (isa == "amx") {
|
||||
return cpu_attention::ISA::AMX;
|
||||
} else if (isa == "vec") {
|
||||
return cpu_attention::ISA::VEC;
|
||||
} else if (isa == "vec16") {
|
||||
return cpu_attention::ISA::VEC16;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid ISA type: " + isa);
|
||||
}
|
||||
}();
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() {
|
||||
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
|
||||
CPU_ATTN_DISPATCH_IMPL(isa_tag, [&]() {
|
||||
attn_impl::reshape_and_cache(
|
||||
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(), token_num,
|
||||
key_token_num_stride, value_token_num_stride, head_num,
|
||||
key_head_num_stride, value_head_num_stride, num_blocks,
|
||||
num_blocks_stride, cache_head_num_stride, block_size,
|
||||
block_size_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void cpu_attention_with_kv_cache(
|
||||
const torch::Tensor& query, // [num_tokens, num_heads, head_size]
|
||||
const torch::Tensor&
|
||||
key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
|
||||
const torch::Tensor&
|
||||
value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
|
||||
torch::Tensor& output, // [num_tokens, num_heads, head_size]
|
||||
const torch::Tensor& query_start_loc, // [num_tokens + 1]
|
||||
const torch::Tensor& seq_lens, // [num_tokens]
|
||||
const double scale, const bool causal,
|
||||
const std::optional<torch::Tensor>& alibi_slopes, // [num_heads]
|
||||
const int64_t sliding_window_left, const int64_t sliding_window_right,
|
||||
const torch::Tensor& block_table, // [num_tokens, max_block_num]
|
||||
const double softcap, const torch::Tensor& scheduler_metadata,
|
||||
const std::optional<torch::Tensor>& s_aux // [num_heads]
|
||||
) {
|
||||
TORCH_CHECK_EQ(query.dim(), 3);
|
||||
TORCH_CHECK_EQ(query.stride(2), 1);
|
||||
TORCH_CHECK_EQ(key_cache.dim(), 4);
|
||||
TORCH_CHECK_EQ(value_cache.dim(), 4);
|
||||
|
||||
cpu_attention::AttentionInput input;
|
||||
input.metadata = reinterpret_cast<cpu_attention::AttentionMetadata*>(
|
||||
scheduler_metadata.data_ptr());
|
||||
input.num_tokens = query.size(0);
|
||||
input.num_heads = query.size(1);
|
||||
input.num_kv_heads = key_cache.size(1);
|
||||
input.block_size = key_cache.size(2);
|
||||
input.query = query.data_ptr();
|
||||
input.query_num_tokens_stride = query.stride(0);
|
||||
input.query_num_heads_stride = query.stride(1);
|
||||
input.cache_num_blocks_stride = key_cache.stride(0);
|
||||
input.cache_num_kv_heads_stride = key_cache.stride(1);
|
||||
input.blt_num_tokens_stride = block_table.stride(0);
|
||||
input.key_cache = key_cache.data_ptr();
|
||||
input.value_cache = value_cache.data_ptr();
|
||||
input.output = output.data_ptr();
|
||||
input.query_start_loc = query_start_loc.data_ptr<int32_t>();
|
||||
input.seq_lens = seq_lens.data_ptr<int32_t>();
|
||||
input.block_table = block_table.data_ptr<int32_t>();
|
||||
input.alibi_slopes =
|
||||
alibi_slopes.has_value() ? alibi_slopes->data_ptr<float>() : nullptr;
|
||||
// For now sink must be bf16
|
||||
input.s_aux = s_aux.has_value() ? s_aux->data_ptr<c10::BFloat16>() : nullptr;
|
||||
input.scale = scale;
|
||||
input.causal = causal;
|
||||
input.sliding_window_left = sliding_window_left;
|
||||
input.sliding_window_right = sliding_window_right;
|
||||
if (input.causal) {
|
||||
// to make boundary calculation easier
|
||||
input.sliding_window_right = 0;
|
||||
}
|
||||
float softcap_fp32 = softcap;
|
||||
input.softcap = softcap_fp32;
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
query.scalar_type(), "cpu_attention_with_kv_cache", [&]() {
|
||||
CPU_ATTN_DISPATCH_CASE_HEADDIM(query.size(2), [&] {
|
||||
CPU_ATTN_DISPATCH_IMPL(input.metadata->isa, [&]() {
|
||||
TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0);
|
||||
cpu_attention::AttentionMainLoop<attn_impl> mainloop;
|
||||
mainloop(&input);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
511
csrc/cpu/cpu_attn_amx.hpp
Normal file
511
csrc/cpu/cpu_attn_amx.hpp
Normal file
@ -0,0 +1,511 @@
|
||||
#ifndef CPU_ATTN_AMX_HPP
|
||||
#define CPU_ATTN_AMX_HPP
|
||||
|
||||
#include "cpu_attn_impl.hpp"
|
||||
|
||||
namespace cpu_attention {
|
||||
namespace {
|
||||
// AMX specific
|
||||
constexpr static int64_t AMX_TILE_ROW_BYTES = 64;
|
||||
constexpr static int64_t AMX_TILE_ROW_NUM = 16;
|
||||
constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM;
|
||||
|
||||
typedef struct __tile_config {
|
||||
uint8_t palette_id = 1;
|
||||
uint8_t start_row = 0;
|
||||
uint8_t reserved_0[14] = {0};
|
||||
uint16_t colsb[16] = {0};
|
||||
uint8_t rows[16] = {0};
|
||||
} __tilecfg;
|
||||
|
||||
// 2-2-4 pattern, for 16 < m <= 32
|
||||
// TILE 0, 1: load A matrix, row num should be 16, m - 16
|
||||
// TILE 2, 3: load B matrix, row num should be 16
|
||||
// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16, m
|
||||
// - 16
|
||||
template <typename kv_cache_t>
|
||||
class TileGemm224 {
|
||||
public:
|
||||
template <AttentionGemmPhase phase, int32_t k_size>
|
||||
FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile,
|
||||
void* __restrict__ b_tile,
|
||||
float* __restrict__ c_tile, const int64_t lda,
|
||||
const int64_t ldb, const int64_t ldc,
|
||||
const int32_t block_size,
|
||||
const int32_t dynamic_k_size,
|
||||
const bool accum_c) {
|
||||
TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224");
|
||||
}
|
||||
|
||||
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
|
||||
TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class TileGemm224<c10::BFloat16> {
|
||||
public:
|
||||
template <AttentionGemmPhase phase, int32_t k_size>
|
||||
FORCE_INLINE static void gemm(const int32_t m_size,
|
||||
c10::BFloat16* __restrict__ a_tile,
|
||||
c10::BFloat16* __restrict__ b_tile,
|
||||
float* __restrict__ c_tile, const int64_t lda,
|
||||
const int64_t ldb, const int64_t ldc,
|
||||
const int32_t block_size,
|
||||
const int32_t dynamic_k_size,
|
||||
const bool accum_c) {
|
||||
const int32_t k_times =
|
||||
dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
|
||||
c10::BFloat16* __restrict__ a_tile_0 = a_tile;
|
||||
c10::BFloat16* __restrict__ a_tile_1 = a_tile + lda * AMX_TILE_ROW_NUM;
|
||||
const int64_t a_tile_stride = [&]() {
|
||||
if constexpr (phase == AttentionGemmPhase::QK) {
|
||||
// q_buffer is prepacked
|
||||
return AMX_TILE_ROW_BYTES;
|
||||
} else if constexpr (phase == AttentionGemmPhase::PV) {
|
||||
// logits_buffer is row-major
|
||||
return lda * sizeof(c10::BFloat16);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unreachable");
|
||||
}
|
||||
}();
|
||||
|
||||
c10::BFloat16* __restrict__ b_tile_2 = b_tile;
|
||||
c10::BFloat16* __restrict__ b_tile_3 = [&]() {
|
||||
if constexpr (phase == AttentionGemmPhase::QK) {
|
||||
// k_cache is prepacked
|
||||
return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4);
|
||||
} else if constexpr (phase == AttentionGemmPhase::PV) {
|
||||
// v_cache is prepacked
|
||||
return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unreachable");
|
||||
}
|
||||
}();
|
||||
// k_cache, v_cache are prepacked
|
||||
const int32_t b_tile_stride = AMX_TILE_ROW_BYTES;
|
||||
|
||||
// logits_buffer, output_buffer are not prepacked
|
||||
float* __restrict__ c_tile_4 = c_tile;
|
||||
float* __restrict__ c_tile_5 =
|
||||
c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float);
|
||||
float* __restrict__ c_tile_6 = c_tile + AMX_TILE_ROW_NUM * ldc;
|
||||
float* __restrict__ c_tile_7 =
|
||||
c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float);
|
||||
const int32_t c_tile_stride = ldc * sizeof(float);
|
||||
|
||||
if (accum_c) {
|
||||
_tile_loadd(4, c_tile_4, c_tile_stride);
|
||||
_tile_loadd(5, c_tile_5, c_tile_stride);
|
||||
_tile_loadd(6, c_tile_6, c_tile_stride);
|
||||
_tile_loadd(7, c_tile_7, c_tile_stride);
|
||||
} else {
|
||||
_tile_zero(4);
|
||||
_tile_zero(5);
|
||||
_tile_zero(6);
|
||||
_tile_zero(7);
|
||||
}
|
||||
|
||||
for (int32_t k = 0; k < k_times; ++k) {
|
||||
_tile_loadd(0, a_tile_0, a_tile_stride);
|
||||
_tile_stream_loadd(2, b_tile_2, b_tile_stride);
|
||||
_tile_dpbf16ps(4, 0, 2);
|
||||
_tile_stream_loadd(3, b_tile_3, b_tile_stride);
|
||||
_tile_dpbf16ps(5, 0, 3);
|
||||
_tile_loadd(1, a_tile_1, a_tile_stride);
|
||||
_tile_dpbf16ps(6, 1, 2);
|
||||
_tile_dpbf16ps(7, 1, 3);
|
||||
|
||||
// update ptrs
|
||||
if constexpr (phase == AttentionGemmPhase::QK) {
|
||||
// Q buffer is prepacked
|
||||
a_tile_0 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
a_tile_1 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
} else if constexpr (phase == AttentionGemmPhase::PV) {
|
||||
// P buffer is not prepacked
|
||||
a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
|
||||
a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unreachable");
|
||||
}
|
||||
b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
}
|
||||
|
||||
_tile_stored(4, c_tile_4, c_tile_stride);
|
||||
_tile_stored(5, c_tile_5, c_tile_stride);
|
||||
_tile_stored(6, c_tile_6, c_tile_stride);
|
||||
_tile_stored(7, c_tile_7, c_tile_stride);
|
||||
}
|
||||
|
||||
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
|
||||
const int32_t m_0 = AMX_TILE_ROW_NUM;
|
||||
const int32_t m_1 = m - AMX_TILE_ROW_NUM;
|
||||
config.rows[0] = m_0;
|
||||
config.rows[1] = m_1;
|
||||
config.rows[2] = AMX_TILE_ROW_NUM;
|
||||
config.rows[3] = AMX_TILE_ROW_NUM;
|
||||
config.rows[4] = m_0;
|
||||
config.rows[5] = m_0;
|
||||
config.rows[6] = m_1;
|
||||
config.rows[7] = m_1;
|
||||
_tile_loadconfig(&config);
|
||||
}
|
||||
};
|
||||
|
||||
// 1-2-2 pattern, for 0 < m <= 16
|
||||
// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be
|
||||
// m, m
|
||||
// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row
|
||||
// num should be 16
|
||||
// TILE 6, 7, (6, 7): store results C matrix, row num should be
|
||||
// m
|
||||
template <typename kv_cache_t>
|
||||
class TileGemm122 {
|
||||
public:
|
||||
template <AttentionGemmPhase phase, int32_t k_size>
|
||||
FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile,
|
||||
void* __restrict__ b_tile,
|
||||
float* __restrict__ c_tile, const int64_t lda,
|
||||
const int64_t ldb, const int64_t ldc,
|
||||
const int32_t block_size,
|
||||
const int32_t dynamic_k_size,
|
||||
const bool accum_c) {
|
||||
TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122");
|
||||
}
|
||||
|
||||
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
|
||||
TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class TileGemm122<c10::BFloat16> {
|
||||
public:
|
||||
template <AttentionGemmPhase phase, int32_t k_size>
|
||||
FORCE_INLINE static void gemm(const int32_t m_size,
|
||||
c10::BFloat16* __restrict__ a_tile,
|
||||
c10::BFloat16* __restrict__ b_tile,
|
||||
float* __restrict__ c_tile, const int64_t lda,
|
||||
const int64_t ldb, const int64_t ldc,
|
||||
const int32_t block_size,
|
||||
const int32_t dynamic_k_size,
|
||||
const bool accum_c) {
|
||||
c10::BFloat16* __restrict__ a_tile_0 = a_tile;
|
||||
c10::BFloat16* __restrict__ a_tile_1 = [&]() {
|
||||
if constexpr (phase == AttentionGemmPhase::QK) {
|
||||
// q_buffer is prepacked
|
||||
return a_tile + AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
} else if constexpr (phase == AttentionGemmPhase::PV) {
|
||||
// logits_buffer is row-major
|
||||
return a_tile + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unreachable");
|
||||
}
|
||||
}();
|
||||
const int64_t a_tile_stride = [&]() {
|
||||
if constexpr (phase == AttentionGemmPhase::QK) {
|
||||
// q_buffer is prepacked
|
||||
return AMX_TILE_ROW_BYTES;
|
||||
} else if constexpr (phase == AttentionGemmPhase::PV) {
|
||||
// logits_buffer is row-major
|
||||
return lda * sizeof(c10::BFloat16);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unreachable");
|
||||
}
|
||||
}();
|
||||
|
||||
c10::BFloat16* __restrict__ b_tile_2 = b_tile;
|
||||
c10::BFloat16* __restrict__ b_tile_3 = [&]() {
|
||||
if constexpr (phase == AttentionGemmPhase::QK) {
|
||||
// k_cache is prepacked
|
||||
return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4);
|
||||
} else if constexpr (phase == AttentionGemmPhase::PV) {
|
||||
// v_cache is prepacked
|
||||
return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unreachable");
|
||||
}
|
||||
}();
|
||||
c10::BFloat16* __restrict__ b_tile_4 =
|
||||
b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
c10::BFloat16* __restrict__ b_tile_5 =
|
||||
b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
int64_t b_stride = AMX_TILE_ROW_BYTES;
|
||||
|
||||
float* __restrict__ c_tile_6 = c_tile;
|
||||
float* __restrict__ c_tile_7 = c_tile + AMX_TILE_ROW_BYTES / sizeof(float);
|
||||
int64_t c_stride = ldc * sizeof(float);
|
||||
|
||||
const int32_t k_times =
|
||||
dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
|
||||
const int32_t k_group_times = k_times / 2;
|
||||
const bool has_tail = (k_times % 2 == 1);
|
||||
|
||||
if (accum_c) {
|
||||
_tile_loadd(6, c_tile_6, c_stride);
|
||||
_tile_loadd(7, c_tile_7, c_stride);
|
||||
} else {
|
||||
_tile_zero(6);
|
||||
_tile_zero(7);
|
||||
}
|
||||
|
||||
for (int32_t k = 0; k < k_group_times; ++k) {
|
||||
_tile_loadd(0, a_tile_0, a_tile_stride);
|
||||
_tile_stream_loadd(2, b_tile_2, b_stride);
|
||||
_tile_dpbf16ps(6, 0, 2);
|
||||
_tile_stream_loadd(3, b_tile_3, b_stride);
|
||||
_tile_dpbf16ps(7, 0, 3);
|
||||
_tile_loadd(1, a_tile_1, a_tile_stride);
|
||||
_tile_stream_loadd(4, b_tile_4, b_stride);
|
||||
_tile_dpbf16ps(6, 1, 4);
|
||||
_tile_stream_loadd(5, b_tile_5, b_stride);
|
||||
_tile_dpbf16ps(7, 1, 5);
|
||||
|
||||
// update ptrs
|
||||
if constexpr (phase == AttentionGemmPhase::QK) {
|
||||
// Q buffer is prepacked
|
||||
a_tile_0 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
a_tile_1 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
} else if constexpr (phase == AttentionGemmPhase::PV) {
|
||||
// P buffer is not prepacked
|
||||
a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
|
||||
a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
|
||||
}
|
||||
b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
b_tile_5 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
}
|
||||
|
||||
if (has_tail) {
|
||||
_tile_loadd(0, a_tile_0, a_tile_stride);
|
||||
_tile_stream_loadd(2, b_tile_2, b_stride);
|
||||
_tile_dpbf16ps(6, 0, 2);
|
||||
_tile_stream_loadd(3, b_tile_3, b_stride);
|
||||
_tile_dpbf16ps(7, 0, 3);
|
||||
}
|
||||
|
||||
_tile_stored(6, c_tile_6, c_stride);
|
||||
_tile_stored(7, c_tile_7, c_stride);
|
||||
}
|
||||
|
||||
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
|
||||
config.rows[0] = m;
|
||||
config.rows[1] = m;
|
||||
config.rows[2] = AMX_TILE_ROW_NUM;
|
||||
config.rows[3] = AMX_TILE_ROW_NUM;
|
||||
config.rows[4] = AMX_TILE_ROW_NUM;
|
||||
config.rows[5] = AMX_TILE_ROW_NUM;
|
||||
config.rows[6] = m;
|
||||
config.rows[7] = m;
|
||||
_tile_loadconfig(&config);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
template <typename scalar_t, int64_t head_dim>
|
||||
class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
|
||||
public:
|
||||
using query_t = scalar_t;
|
||||
using q_buffer_t = scalar_t;
|
||||
using kv_cache_t = scalar_t;
|
||||
using logits_buffer_t = float;
|
||||
using partial_output_buffer_t = float;
|
||||
using prob_buffer_t = scalar_t;
|
||||
|
||||
constexpr static int64_t BlockSizeAlignment =
|
||||
AMX_TILE_ROW_BYTES /
|
||||
sizeof(kv_cache_t); // KV token num unit of QK and PV phases
|
||||
constexpr static int64_t HeadDimAlignment =
|
||||
2 * (AMX_TILE_ROW_BYTES / 4); // headdim num unit of PV phase
|
||||
constexpr static int64_t MaxQHeadNumPerIteration = 32;
|
||||
constexpr static int64_t HeadDim = head_dim;
|
||||
constexpr static ISA ISAType = ISA::AMX;
|
||||
constexpr static bool scale_on_logits = true;
|
||||
|
||||
public:
|
||||
AttentionImpl() : current_q_head_num_(0) {
|
||||
// Use all columns in AMX tiles
|
||||
vec_op::unroll_loop<int, 8>([&](int i) { amx_tile_config_.colsb[i] = 64; });
|
||||
}
|
||||
|
||||
~AttentionImpl() { _tile_release(); }
|
||||
|
||||
template <template <typename tile_gemm_t> typename attention>
|
||||
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
|
||||
if (q_head_num > AMX_TILE_ROW_NUM) {
|
||||
if (q_head_num != current_q_head_num_) {
|
||||
current_q_head_num_ = q_head_num;
|
||||
TileGemm224<kv_cache_t>::init_tile_config(q_head_num, amx_tile_config_);
|
||||
}
|
||||
attention<TileGemm224<kv_cache_t>> attention_iteration;
|
||||
attention_iteration(CPU_ATTENTION_PARAMS);
|
||||
} else {
|
||||
if (q_head_num != current_q_head_num_) {
|
||||
current_q_head_num_ = q_head_num;
|
||||
TileGemm122<kv_cache_t>::init_tile_config(q_head_num, amx_tile_config_);
|
||||
}
|
||||
attention<TileGemm122<kv_cache_t>> attention_iteration;
|
||||
attention_iteration(CPU_ATTENTION_PARAMS);
|
||||
}
|
||||
}
|
||||
|
||||
// k_cache_token_group_stride: stride of K cache when move to next
|
||||
// BlockSizeAlignment tokens in a block
|
||||
constexpr static int64_t k_cache_token_group_stride(
|
||||
const int32_t block_size) {
|
||||
return BlockSizeAlignment * head_dim;
|
||||
}
|
||||
|
||||
// v_cache_token_group_stride: stride of V cache when move to next
|
||||
// BlockSizeAlignment tokens in a block
|
||||
constexpr static int64_t v_cache_token_group_stride(
|
||||
const int32_t block_size) {
|
||||
return BlockSizeAlignment * (AMX_TILE_ROW_BYTES / 4);
|
||||
}
|
||||
|
||||
// v_cache_head_group_stride: stride of V cache when move to next
|
||||
// HeadDimAlignment head dims in a block
|
||||
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
|
||||
return block_size * HeadDimAlignment;
|
||||
}
|
||||
|
||||
static void copy_q_heads_tile(
|
||||
scalar_t* __restrict__ src, // [q_num, q_heads_per_kv, head_size]
|
||||
scalar_t* __restrict__ q_buffer, const int32_t q_num,
|
||||
const int32_t q_heads_per_kv, const int64_t q_num_stride,
|
||||
const int64_t q_head_stride, const float scale) {
|
||||
constexpr int64_t bytes_per_head = head_dim * sizeof(scalar_t);
|
||||
static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0);
|
||||
constexpr int64_t head_size_block_num = bytes_per_head / AMX_TILE_ROW_BYTES;
|
||||
constexpr int64_t head_elem_num_pre_block =
|
||||
AMX_TILE_ROW_BYTES / sizeof(scalar_t);
|
||||
|
||||
int32_t idx = 0;
|
||||
int8_t* __restrict__ q_buffer_iter = reinterpret_cast<int8_t*>(q_buffer);
|
||||
for (int32_t q_num_idx = 0; q_num_idx < q_num;
|
||||
++q_num_idx, src += q_num_stride) {
|
||||
scalar_t* __restrict__ src_iter = src;
|
||||
for (int32_t q_head_idx = 0; q_head_idx < q_heads_per_kv;
|
||||
++q_head_idx, src_iter += q_head_stride) {
|
||||
vec_op::unroll_loop<int32_t, head_size_block_num>(
|
||||
[&](int32_t head_size_block_idx) {
|
||||
// Use INT8Vec64 for 64 bytes block
|
||||
vec_op::INT8Vec64 vec(src_iter + head_size_block_idx *
|
||||
head_elem_num_pre_block);
|
||||
vec.save(q_buffer_iter + head_size_block_idx * AMX_TILE_BYTES);
|
||||
});
|
||||
|
||||
++idx;
|
||||
q_buffer_iter += AMX_TILE_ROW_BYTES;
|
||||
if ((idx & (AMX_TILE_ROW_NUM - 1)) == 0) {
|
||||
// head is in another amx tile
|
||||
q_buffer_iter -= AMX_TILE_ROW_NUM * AMX_TILE_ROW_BYTES;
|
||||
q_buffer_iter += head_size_block_num * AMX_TILE_BYTES;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reshape KV to AMX friendly layout
|
||||
static void reshape_and_cache(
|
||||
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
|
||||
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
|
||||
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
|
||||
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
|
||||
const int64_t head_num, const int64_t key_head_num_stride,
|
||||
const int64_t value_head_num_stride, const int64_t num_blocks,
|
||||
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
|
||||
const int64_t block_size, const int64_t block_size_stride) {
|
||||
// For AMX 2D tiles, size of each line is 64 bytes
|
||||
constexpr int64_t amx_tile_row_size = AMX_TILE_ROW_BYTES;
|
||||
// For AMX B martix, N always is 16
|
||||
constexpr int64_t amx_b_tile_n_size = AMX_TILE_ROW_BYTES / 4;
|
||||
constexpr int64_t amx_b_tile_k_size = amx_tile_row_size / sizeof(scalar_t);
|
||||
// For now suppose block_size is divisible by amx_tile_column_num
|
||||
TORCH_CHECK_EQ(block_size % amx_b_tile_k_size, 0);
|
||||
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
|
||||
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
|
||||
const int64_t pos = slot_mapping[token_idx];
|
||||
if (pos < 0) {
|
||||
// skip
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64_t block_idx = pos / block_size;
|
||||
const int64_t block_offset = pos % block_size;
|
||||
{
|
||||
// Write Key
|
||||
// Head elements should be packed as quand-words and stored in token
|
||||
// groups with (quadword_stride/4) tokens
|
||||
constexpr int64_t token_num_per_group = amx_tile_row_size / 4;
|
||||
static_assert(head_dim % (4 / sizeof(scalar_t)) == 0);
|
||||
constexpr int64_t quadword_num = head_dim / (4 / sizeof(scalar_t));
|
||||
const int32_t* key_start_quadword_ptr =
|
||||
reinterpret_cast<const int32_t*>(
|
||||
key + token_idx * key_token_num_stride +
|
||||
head_idx * key_head_num_stride);
|
||||
const int64_t group_idx = block_offset / token_num_per_group;
|
||||
const int64_t group_offset = block_offset % token_num_per_group;
|
||||
constexpr int64_t quadword_num_per_group =
|
||||
token_num_per_group * quadword_num;
|
||||
int32_t* key_cache_start_ptr =
|
||||
reinterpret_cast<int32_t*>(key_cache +
|
||||
block_idx * num_blocks_stride +
|
||||
head_idx * cache_head_num_stride) +
|
||||
group_idx * quadword_num_per_group + group_offset;
|
||||
|
||||
#pragma GCC unroll 8
|
||||
for (int64_t i = 0, j = 0; j < quadword_num;
|
||||
i += token_num_per_group, ++j) {
|
||||
key_cache_start_ptr[i] = key_start_quadword_ptr[j];
|
||||
}
|
||||
}
|
||||
{
|
||||
// Write Value
|
||||
// Different from Key, block_size dimension is packed rather than
|
||||
// head_size dimension block_size dimension is packed as quand-words;
|
||||
constexpr int64_t token_num_per_sub_group = 4 / sizeof(scalar_t);
|
||||
const int64_t token_num_per_group = block_size;
|
||||
constexpr int64_t head_elems_per_group = amx_b_tile_n_size;
|
||||
const int64_t group_size = token_num_per_group * head_elems_per_group;
|
||||
// For now suppose head_dim is divisible by amx_b_tile_n_size
|
||||
static_assert(head_dim % head_elems_per_group == 0);
|
||||
constexpr int64_t group_num = head_dim / head_elems_per_group;
|
||||
const int64_t sub_group_idx = block_offset / token_num_per_sub_group;
|
||||
const int64_t sub_group_offset =
|
||||
block_offset % token_num_per_sub_group;
|
||||
|
||||
const scalar_t* value_start_ptr = value +
|
||||
token_idx * value_token_num_stride +
|
||||
head_idx * value_head_num_stride;
|
||||
scalar_t* value_cache_start_ptr =
|
||||
value_cache + block_idx * num_blocks_stride +
|
||||
head_idx * cache_head_num_stride +
|
||||
sub_group_idx * token_num_per_sub_group * amx_b_tile_n_size +
|
||||
sub_group_offset;
|
||||
|
||||
for (int64_t i = 0; i < group_num; ++i) {
|
||||
#pragma GCC unroll head_elems_per_group
|
||||
for (int64_t j = 0, k = 0; j < head_elems_per_group;
|
||||
++j, k += token_num_per_sub_group) {
|
||||
value_cache_start_ptr[k] = value_start_ptr[j];
|
||||
}
|
||||
value_start_ptr += head_elems_per_group;
|
||||
value_cache_start_ptr += group_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
alignas(64) __tilecfg amx_tile_config_;
|
||||
int32_t current_q_head_num_;
|
||||
};
|
||||
} // namespace cpu_attention
|
||||
|
||||
#endif
|
||||
1977
csrc/cpu/cpu_attn_impl.hpp
Normal file
1977
csrc/cpu/cpu_attn_impl.hpp
Normal file
File diff suppressed because it is too large
Load Diff
63
csrc/cpu/cpu_attn_macros.h
Normal file
63
csrc/cpu/cpu_attn_macros.h
Normal file
@ -0,0 +1,63 @@
|
||||
#ifndef CPU_ATTN_MACROS_H
|
||||
#define CPU_ATTN_MACROS_H
|
||||
|
||||
// x86_64
|
||||
#ifdef __x86_64__
|
||||
#define FAST_SPINNING _mm_pause();
|
||||
|
||||
#ifdef __AVX512F__
|
||||
#define DEFINE_FAST_EXP \
|
||||
const __m512 vec_factorial_1 = _mm512_set1_ps(0.999999701f); \
|
||||
const __m512 vec_factorial_2 = _mm512_set1_ps(0.499991506f); \
|
||||
const __m512 vec_factorial_3 = _mm512_set1_ps(0.166676521f); \
|
||||
const __m512 vec_factorial_4 = _mm512_set1_ps(0.0418978221f); \
|
||||
const __m512 vec_factorial_5 = _mm512_set1_ps(0.00828929059f); \
|
||||
const __m512 vec_exp_log2ef = \
|
||||
_mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); \
|
||||
const __m512 vec_half = _mm512_set1_ps(0.5f); \
|
||||
const __m512 vec_one = _mm512_set1_ps(1.f); \
|
||||
const __m512 vec_zero = _mm512_set1_ps(0.f); \
|
||||
const __m512 vec_two = _mm512_set1_ps(2.f); \
|
||||
const __m512 vec_ln2f = \
|
||||
_mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); \
|
||||
const __m512 vec_ln_flt_min = \
|
||||
_mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); \
|
||||
const __m512 vec_ln_flt_max = \
|
||||
_mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); \
|
||||
const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); \
|
||||
const int n_mantissa_bits = 23; \
|
||||
auto fast_exp = [&](vec_op::FP32Vec16& vec) __attribute__(( \
|
||||
always_inline)) { \
|
||||
__m512 values = vec.reg; \
|
||||
auto less_ln_flt_min_mask = \
|
||||
_mm512_cmp_ps_mask(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/); \
|
||||
auto vec_src = _mm512_min_ps(values, vec_ln_flt_max); \
|
||||
vec_src = _mm512_max_ps(vec_src, vec_ln_flt_min); \
|
||||
auto vec_fx = _mm512_fmadd_ps(vec_src, vec_exp_log2ef, vec_half); \
|
||||
auto vec_fx_i = _mm512_cvt_roundps_epi32( \
|
||||
vec_fx, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); \
|
||||
vec_fx = _mm512_cvtepi32_ps(vec_fx_i); \
|
||||
auto vec_exp_poly = _mm512_fnmadd_ps(vec_fx, vec_ln2f, vec_src); \
|
||||
auto vec_res = \
|
||||
_mm512_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4); \
|
||||
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3); \
|
||||
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2); \
|
||||
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1); \
|
||||
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_one); \
|
||||
auto vec_exp_number = _mm512_sub_ps(vec_fx, vec_one); \
|
||||
auto vec_exp_number_i = _mm512_cvtps_epi32(vec_exp_number); \
|
||||
auto vec_two_pow_n_i = _mm512_add_epi32(vec_exp_number_i, vec_127); \
|
||||
vec_two_pow_n_i = _mm512_slli_epi32(vec_two_pow_n_i, n_mantissa_bits); \
|
||||
auto vec_two_pow_n = _mm512_castsi512_ps(vec_two_pow_n_i); \
|
||||
vec_two_pow_n = _mm512_mask_blend_ps(less_ln_flt_min_mask, \
|
||||
vec_two_pow_n, vec_zero); \
|
||||
vec_res = _mm512_mul_ps(vec_res, vec_two_pow_n); \
|
||||
vec_res = _mm512_mul_ps(vec_res, vec_two); \
|
||||
vec_op::FP32Vec16 res(vec_res); \
|
||||
return res; \
|
||||
};
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
248
csrc/cpu/cpu_attn_vec.hpp
Normal file
248
csrc/cpu/cpu_attn_vec.hpp
Normal file
@ -0,0 +1,248 @@
|
||||
#ifndef CPU_ATTN_VEC_HPP
|
||||
#define CPU_ATTN_VEC_HPP
|
||||
|
||||
#include "cpu_attn_impl.hpp"
|
||||
|
||||
namespace cpu_attention {
|
||||
|
||||
namespace {
|
||||
// 8-2-16 pattern, 8 regs for A, 2 regs for B, 16 regs for C, [8, K] @ [k, 32]
|
||||
template <typename kv_cache_t>
|
||||
class TileGemm82 {
|
||||
public:
|
||||
template <AttentionGemmPhase phase, int32_t k_size>
|
||||
FORCE_INLINE static void gemm(const int32_t m_size,
|
||||
float* __restrict__ a_tile,
|
||||
kv_cache_t* __restrict__ b_tile,
|
||||
float* __restrict__ c_tile, const int64_t lda,
|
||||
const int64_t ldb, const int64_t ldc,
|
||||
const int32_t block_size,
|
||||
const int32_t dynamic_k_size,
|
||||
const bool accum_c) {
|
||||
switch (m_size) {
|
||||
case 1:
|
||||
gemm_micro<1>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 2:
|
||||
gemm_micro<2>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 3:
|
||||
case 4:
|
||||
gemm_micro<4>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 5:
|
||||
case 6:
|
||||
gemm_micro<6>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 7:
|
||||
case 8:
|
||||
gemm_micro<8>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <int32_t M>
|
||||
static void gemm_micro(float* __restrict__ a_tile,
|
||||
kv_cache_t* __restrict__ b_tile,
|
||||
float* __restrict__ c_tile, const int64_t lda,
|
||||
const int64_t ldb, const int64_t ldc,
|
||||
const int32_t block_size, const int32_t dynamic_k_size,
|
||||
const bool accum_c) {
|
||||
static_assert(0 < M <= 8);
|
||||
using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
|
||||
|
||||
kv_cache_t* __restrict__ curr_b_0 = b_tile;
|
||||
kv_cache_t* __restrict__ curr_b_1 = b_tile + 16;
|
||||
float* __restrict__ curr_c_0 = c_tile;
|
||||
float* __restrict__ curr_c_1 = c_tile + 16;
|
||||
|
||||
vec_op::FP32Vec16 c_regs[M * 2];
|
||||
if (accum_c) {
|
||||
float* __restrict__ curr_m_c_0 = curr_c_0;
|
||||
float* __restrict__ curr_m_c_1 = curr_c_1;
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
c_regs[i * 2] = vec_op::FP32Vec16(curr_m_c_0);
|
||||
c_regs[i * 2 + 1] = vec_op::FP32Vec16(curr_m_c_1);
|
||||
|
||||
// update
|
||||
curr_m_c_0 += ldc;
|
||||
curr_m_c_1 += ldc;
|
||||
});
|
||||
}
|
||||
|
||||
float* __restrict__ curr_a = a_tile;
|
||||
for (int32_t k = 0; k < dynamic_k_size; ++k) {
|
||||
load_vec_t b_0_reg(curr_b_0);
|
||||
vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
|
||||
load_vec_t b_1_reg(curr_b_1);
|
||||
vec_op::FP32Vec16 fp32_b_1_reg(b_1_reg);
|
||||
|
||||
float* __restrict__ curr_m_a = curr_a;
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
float v = *curr_m_a;
|
||||
vec_op::FP32Vec16 a_reg(v);
|
||||
c_regs[i * 2] = c_regs[i * 2] + a_reg * fp32_b_0_reg;
|
||||
c_regs[i * 2 + 1] = c_regs[i * 2 + 1] + a_reg * fp32_b_1_reg;
|
||||
|
||||
// update
|
||||
curr_m_a += lda;
|
||||
});
|
||||
|
||||
// update
|
||||
curr_a += 1;
|
||||
curr_b_0 += ldb;
|
||||
curr_b_1 += ldb;
|
||||
}
|
||||
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
c_regs[i * 2].save(curr_c_0);
|
||||
c_regs[i * 2 + 1].save(curr_c_1);
|
||||
|
||||
// update
|
||||
curr_c_0 += ldc;
|
||||
curr_c_1 += ldc;
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// This is a general but naive implementation based on vector instructions
|
||||
template <typename scalar_t, int64_t head_dim>
|
||||
class AttentionImpl<ISA::VEC, scalar_t, head_dim> {
|
||||
public:
|
||||
using query_t = scalar_t;
|
||||
using q_buffer_t = float;
|
||||
using kv_cache_t = scalar_t;
|
||||
using logits_buffer_t = float;
|
||||
using partial_output_buffer_t = float;
|
||||
using prob_buffer_t = float;
|
||||
|
||||
constexpr static int64_t BlockSizeAlignment =
|
||||
32; // KV token num unit of QK and PV phases
|
||||
constexpr static int64_t HeadDimAlignment =
|
||||
32; // headdim num unit of PV phase
|
||||
constexpr static int64_t MaxQHeadNumPerIteration = 8;
|
||||
constexpr static int64_t HeadDim = head_dim;
|
||||
constexpr static ISA ISAType = ISA::VEC;
|
||||
constexpr static bool scale_on_logits = false; // apply scale on q_buffer
|
||||
|
||||
public:
|
||||
template <template <typename tile_gemm_t> typename attention>
|
||||
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
|
||||
attention<TileGemm82<kv_cache_t>> attention_iteration;
|
||||
attention_iteration(CPU_ATTENTION_PARAMS);
|
||||
}
|
||||
|
||||
// k_cache_token_group_stride: stride of K cache when move to next
|
||||
// BlockSizeAlignment tokens in a block
|
||||
constexpr static int64_t k_cache_token_group_stride(
|
||||
const int32_t block_size) {
|
||||
return BlockSizeAlignment; // layout of k_cache block is [head_dim,
|
||||
// block_size], row-major
|
||||
}
|
||||
|
||||
// v_cache_token_group_stride: stride of V cache when move to next
|
||||
// BlockSizeAlignment tokens in a block
|
||||
constexpr static int64_t v_cache_token_group_stride(
|
||||
const int32_t block_size) {
|
||||
return head_dim * BlockSizeAlignment; // layout of v_cache is [block_size,
|
||||
// head_dim], row-major
|
||||
}
|
||||
|
||||
// v_cache_head_group_stride: stride of V cache when move to next
|
||||
// HeadDimAlignment head dims in a block
|
||||
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
|
||||
return HeadDimAlignment; // layout of v_cache is [block_size, head_dim],
|
||||
// row-major
|
||||
}
|
||||
|
||||
// Copy q to q_buffer and cast it to fp32
|
||||
static void copy_q_heads_tile(
|
||||
scalar_t* __restrict__ src, // [q_num, q_heads_per_kv, head_size]
|
||||
float* __restrict__ q_buffer, const int32_t q_num,
|
||||
const int32_t q_heads_per_kv, const int64_t q_num_stride,
|
||||
const int64_t q_head_stride, float scale) {
|
||||
static_assert(head_dim % 16 == 0);
|
||||
constexpr int32_t unroll_size = head_dim / 16;
|
||||
using load_vec_t = typename VecTypeTrait<scalar_t>::vec_t;
|
||||
|
||||
vec_op::FP32Vec16 scale_vec(scale);
|
||||
for (int32_t q_num_idx = 0; q_num_idx < q_num; ++q_num_idx) {
|
||||
for (int32_t q_head_idx = 0; q_head_idx < q_heads_per_kv; ++q_head_idx) {
|
||||
scalar_t* __restrict__ curr_q =
|
||||
src + q_num_idx * q_num_stride + q_head_idx * q_head_stride;
|
||||
float* __restrict__ curr_q_buffer =
|
||||
q_buffer + q_num_idx * q_heads_per_kv * head_dim +
|
||||
q_head_idx * head_dim;
|
||||
|
||||
vec_op::unroll_loop<int32_t, unroll_size>([&](int32_t i) {
|
||||
load_vec_t vec(curr_q);
|
||||
vec_op::FP32Vec16 fp32_vec(vec);
|
||||
fp32_vec = fp32_vec * scale_vec;
|
||||
fp32_vec.save(curr_q_buffer);
|
||||
|
||||
curr_q += 16;
|
||||
curr_q_buffer += 16;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reshape K as column-major and V as row-major
|
||||
static void reshape_and_cache(
|
||||
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
|
||||
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
|
||||
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
|
||||
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
|
||||
const int64_t head_num, const int64_t key_head_num_stride,
|
||||
const int64_t value_head_num_stride, const int64_t num_blocks,
|
||||
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
|
||||
const int64_t block_size, const int64_t block_size_stride) {
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
|
||||
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
|
||||
const int64_t pos = slot_mapping[token_idx];
|
||||
if (pos < 0) {
|
||||
// skip
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64_t block_idx = pos / block_size;
|
||||
const int64_t block_offset = pos % block_size;
|
||||
{
|
||||
// Write Key as column-major
|
||||
const scalar_t* key_start_ptr = key +
|
||||
token_idx * key_token_num_stride +
|
||||
head_idx * key_head_num_stride;
|
||||
scalar_t* key_cache_start_ptr =
|
||||
key_cache + block_idx * num_blocks_stride +
|
||||
head_idx * cache_head_num_stride + block_offset;
|
||||
|
||||
#pragma GCC unroll 8
|
||||
for (int64_t i = 0, j = 0; i < head_dim; ++i, j += block_size) {
|
||||
key_cache_start_ptr[j] = key_start_ptr[i];
|
||||
}
|
||||
}
|
||||
{
|
||||
// Write Value as row-major
|
||||
const scalar_t* value_start_ptr = value +
|
||||
token_idx * value_token_num_stride +
|
||||
head_idx * value_head_num_stride;
|
||||
scalar_t* value_cache_start_ptr =
|
||||
value_cache + block_idx * num_blocks_stride +
|
||||
head_idx * cache_head_num_stride + block_offset * head_dim;
|
||||
std::memcpy(value_cache_start_ptr, value_start_ptr,
|
||||
sizeof(scalar_t) * head_dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace cpu_attention
|
||||
|
||||
#endif
|
||||
171
csrc/cpu/cpu_attn_vec16.hpp
Normal file
171
csrc/cpu/cpu_attn_vec16.hpp
Normal file
@ -0,0 +1,171 @@
|
||||
#ifndef CPU_ATTN_VEC16_HPP
|
||||
#define CPU_ATTN_VEC16_HPP
|
||||
|
||||
#include "cpu_attn_vec.hpp"
|
||||
|
||||
namespace cpu_attention {
|
||||
|
||||
namespace {
|
||||
// 16-1-16 pattern, 16 regs for A, 1 regs for B, 16 regs for C, [16, K] @ [k,
|
||||
// 16]
|
||||
template <typename kv_cache_t>
|
||||
class TileGemm161 {
|
||||
public:
|
||||
template <AttentionGemmPhase phase, int32_t k_size>
|
||||
FORCE_INLINE static void gemm(const int32_t m_size,
|
||||
float* __restrict__ a_tile,
|
||||
kv_cache_t* __restrict__ b_tile,
|
||||
float* __restrict__ c_tile, const int64_t lda,
|
||||
const int64_t ldb, const int64_t ldc,
|
||||
const int32_t block_size,
|
||||
const int32_t dynamic_k_size,
|
||||
const bool accum_c) {
|
||||
switch (m_size) {
|
||||
case 1:
|
||||
gemm_micro<1>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 2:
|
||||
gemm_micro<2>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 3:
|
||||
case 4:
|
||||
gemm_micro<4>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 5:
|
||||
case 6:
|
||||
gemm_micro<6>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 7:
|
||||
case 8:
|
||||
gemm_micro<8>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 9:
|
||||
case 10:
|
||||
case 11:
|
||||
case 12:
|
||||
gemm_micro<12>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 13:
|
||||
case 14:
|
||||
case 15:
|
||||
case 16:
|
||||
gemm_micro<16>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <int32_t M>
|
||||
static void gemm_micro(float* __restrict__ a_tile,
|
||||
kv_cache_t* __restrict__ b_tile,
|
||||
float* __restrict__ c_tile, const int64_t lda,
|
||||
const int64_t ldb, const int64_t ldc,
|
||||
const int32_t block_size, const int32_t dynamic_k_size,
|
||||
const bool accum_c) {
|
||||
static_assert(0 < M <= 16);
|
||||
using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
|
||||
|
||||
kv_cache_t* __restrict__ curr_b_0 = b_tile;
|
||||
float* __restrict__ curr_c_0 = c_tile;
|
||||
|
||||
vec_op::FP32Vec16 c_regs[M];
|
||||
if (accum_c) {
|
||||
float* __restrict__ curr_m_c_0 = curr_c_0;
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
c_regs[i] = vec_op::FP32Vec16(curr_m_c_0);
|
||||
|
||||
// update
|
||||
curr_m_c_0 += ldc;
|
||||
});
|
||||
}
|
||||
|
||||
float* __restrict__ curr_a = a_tile;
|
||||
for (int32_t k = 0; k < dynamic_k_size; ++k) {
|
||||
load_vec_t b_0_reg(curr_b_0);
|
||||
vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
|
||||
|
||||
float* __restrict__ curr_m_a = curr_a;
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
float v = *curr_m_a;
|
||||
vec_op::FP32Vec16 a_reg(v);
|
||||
c_regs[i] = c_regs[i] + a_reg * fp32_b_0_reg;
|
||||
|
||||
// update
|
||||
curr_m_a += lda;
|
||||
});
|
||||
|
||||
// update
|
||||
curr_a += 1;
|
||||
curr_b_0 += ldb;
|
||||
}
|
||||
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
c_regs[i].save(curr_c_0);
|
||||
|
||||
// update
|
||||
curr_c_0 += ldc;
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// This is a general but naive implementation based on vector instructions
|
||||
template <typename scalar_t, int64_t head_dim>
|
||||
class AttentionImpl<ISA::VEC16, scalar_t, head_dim>
|
||||
: public AttentionImpl<ISA::VEC, scalar_t, head_dim> {
|
||||
public:
|
||||
using query_t = scalar_t;
|
||||
using q_buffer_t = float;
|
||||
using kv_cache_t = scalar_t;
|
||||
using logits_buffer_t = float;
|
||||
using partial_output_buffer_t = float;
|
||||
using prob_buffer_t = float;
|
||||
|
||||
constexpr static int64_t BlockSizeAlignment =
|
||||
16; // KV token num unit of QK and PV phases
|
||||
constexpr static int64_t HeadDimAlignment =
|
||||
16; // headdim num unit of PV phase
|
||||
constexpr static int64_t MaxQHeadNumPerIteration = 16;
|
||||
constexpr static int64_t HeadDim = head_dim;
|
||||
constexpr static ISA ISAType = ISA::VEC16;
|
||||
constexpr static bool scale_on_logits = false; // apply scale on q_buffer
|
||||
|
||||
public:
|
||||
template <template <typename tile_gemm_t> typename attention>
|
||||
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
|
||||
attention<TileGemm161<kv_cache_t>> attention_iteration;
|
||||
attention_iteration(CPU_ATTENTION_PARAMS);
|
||||
}
|
||||
|
||||
// k_cache_token_group_stride: stride of K cache when move to next
|
||||
// BlockSizeAlignment tokens in a block
|
||||
constexpr static int64_t k_cache_token_group_stride(
|
||||
const int32_t block_size) {
|
||||
return BlockSizeAlignment; // layout of k_cache block is [head_dim,
|
||||
// block_size], row-major
|
||||
}
|
||||
|
||||
// v_cache_token_group_stride: stride of V cache when move to next
|
||||
// BlockSizeAlignment tokens in a block
|
||||
constexpr static int64_t v_cache_token_group_stride(
|
||||
const int32_t block_size) {
|
||||
return head_dim * BlockSizeAlignment; // layout of v_cache is [block_size,
|
||||
// head_dim], row-major
|
||||
}
|
||||
|
||||
// v_cache_head_group_stride: stride of V cache when move to next
|
||||
// HeadDimAlignment head dims in a block
|
||||
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
|
||||
return HeadDimAlignment; // layout of v_cache is [block_size, head_dim],
|
||||
// row-major
|
||||
}
|
||||
};
|
||||
} // namespace cpu_attention
|
||||
|
||||
#endif
|
||||
@ -40,6 +40,23 @@ namespace vec_op {
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
|
||||
// Function to get the timestamp using RDTSCP
|
||||
FORCE_INLINE uint64_t bench_timestamp() {
|
||||
unsigned int cycles_low, cycles_high;
|
||||
asm volatile(
|
||||
".intel_syntax noprefix\n\t"
|
||||
"CPUID\n\t" // Serialize instruction stream to ensure previous
|
||||
// instructions complete
|
||||
"RDTSCP\n\t" // Read TSC and core ID
|
||||
"mov %0, edx\n\t" // Store high 32 bits of TSC
|
||||
"mov %1, eax\n\t" // Store low 32 bits of TSC
|
||||
".att_syntax"
|
||||
: "=r"(cycles_high), "=r"(cycles_low)::"rax", "rbx", "rcx",
|
||||
"rdx" // Clobbered registers
|
||||
);
|
||||
return (uint64_t)cycles_high << 32 | cycles_low;
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
|
||||
@ -407,6 +424,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
float reduce_min() const { return _mm512_reduce_min_ps(reg); }
|
||||
|
||||
float get_last_elem() const { return _mm512_cvtss_f32(reg); }
|
||||
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
@ -446,9 +465,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec16& data)
|
||||
: reg_low(data.reg_low), reg_high(data.reg_high) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4& data)
|
||||
: reg_low((__m256)_mm256_inserti128_si256(
|
||||
_mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)),
|
||||
@ -504,6 +520,32 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
_mm256_div_ps(reg_high, b.reg_high));
|
||||
}
|
||||
|
||||
FP32Vec16 max(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm256_max_ps(reg_low, b.reg_low),
|
||||
_mm256_max_ps(reg_high, b.reg_high));
|
||||
}
|
||||
|
||||
float reduce_max() const {
|
||||
__m256 v = _mm256_max_ps(reg_low, reg_high);
|
||||
// Permute to compare elements within 128-bit lanes
|
||||
__m256 v_shuffled = _mm256_permute_ps(
|
||||
v, 0b00001011); // Swap halves within each 128-bit lane
|
||||
__m256 v_max = _mm256_max_ps(v, v_shuffled);
|
||||
|
||||
v_shuffled = _mm256_permute_ps(
|
||||
v_max, 0b00000001); // Shuffle elements within each 128-bit lane
|
||||
v_max = _mm256_max_ps(v_max, v_shuffled);
|
||||
|
||||
// Permute to compare elements between 128-bit lanes
|
||||
v_shuffled =
|
||||
_mm256_permute2f128_ps(v_max, v_max, 0b00000001); // Swap 128-bit lanes
|
||||
v_max = _mm256_max_ps(v_max, v_shuffled);
|
||||
|
||||
// At this point, the maximum value is present in all elements of v_max.
|
||||
// Extract the first element for the scalar result.
|
||||
return _mm256_cvtss_f32(v_max); // Extract the lowest 32-bit float
|
||||
}
|
||||
|
||||
float reduce_sum() const {
|
||||
FP32Vec8 low = FP32Vec8(reg_low);
|
||||
FP32Vec8 high = FP32Vec8(reg_high);
|
||||
@ -642,7 +684,7 @@ inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
|
||||
: reg(_mm256_insertf128_si256(
|
||||
_mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg),
|
||||
FP16Vec8(FP32Vec8(v.reg_low)).reg, 1)) {}
|
||||
FP16Vec8(FP32Vec8(v.reg_high)).reg, 1)) {}
|
||||
#endif
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
#include "common/memory.hpp"
|
||||
|
||||
#include "dnnl_helper.h"
|
||||
#include "scratchpad_manager.h"
|
||||
|
||||
static dnnl::engine& default_engine() {
|
||||
static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
|
||||
@ -22,23 +23,6 @@ void release_dnnl_matmul_handler(int64_t handler) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
|
||||
this->realloc(allocation_unit * 128);
|
||||
}
|
||||
|
||||
void DNNLScratchPadManager::realloc(size_t new_size) {
|
||||
new_size = round(new_size);
|
||||
if (new_size > size_) {
|
||||
ptr_ = std::aligned_alloc(64, new_size);
|
||||
size_ = new_size;
|
||||
}
|
||||
}
|
||||
|
||||
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
|
||||
static DNNLScratchPadManager manager;
|
||||
return &manager;
|
||||
}
|
||||
|
||||
template <typename KT, typename VT>
|
||||
class DNNLPrimitiveCache {
|
||||
public:
|
||||
|
||||
@ -59,30 +59,6 @@ constexpr inline dnnl::memory::data_type get_dnnl_type() {
|
||||
return DNNLType<std::decay_t<T>>::type;
|
||||
}
|
||||
|
||||
class DNNLScratchPadManager {
|
||||
public:
|
||||
static constexpr size_t allocation_unit = 4 * 1024 * 1024; // 4KB
|
||||
|
||||
static DNNLScratchPadManager* get_dnnl_scratchpad_manager();
|
||||
|
||||
DNNLScratchPadManager();
|
||||
|
||||
template <typename T>
|
||||
T* get_data() {
|
||||
return reinterpret_cast<T*>(ptr_);
|
||||
}
|
||||
|
||||
static size_t round(size_t size) {
|
||||
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
|
||||
}
|
||||
|
||||
void realloc(size_t new_size);
|
||||
|
||||
private:
|
||||
size_t size_;
|
||||
void* ptr_;
|
||||
};
|
||||
|
||||
class DNNLMatMulPrimitiveHandler {
|
||||
public:
|
||||
virtual ~DNNLMatMulPrimitiveHandler() = default;
|
||||
|
||||
23
csrc/cpu/scratchpad_manager.cpp
Normal file
23
csrc/cpu/scratchpad_manager.cpp
Normal file
@ -0,0 +1,23 @@
|
||||
#include <cstdlib>
|
||||
|
||||
#include "scratchpad_manager.h"
|
||||
|
||||
DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
|
||||
this->realloc(allocation_unit * 128);
|
||||
}
|
||||
|
||||
void DNNLScratchPadManager::realloc(size_t new_size) {
|
||||
new_size = round(new_size);
|
||||
if (new_size > size_) {
|
||||
if (ptr_ != nullptr) {
|
||||
std::free(ptr_);
|
||||
}
|
||||
ptr_ = std::aligned_alloc(64, new_size);
|
||||
size_ = new_size;
|
||||
}
|
||||
}
|
||||
|
||||
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
|
||||
static DNNLScratchPadManager manager;
|
||||
return &manager;
|
||||
}
|
||||
31
csrc/cpu/scratchpad_manager.h
Normal file
31
csrc/cpu/scratchpad_manager.h
Normal file
@ -0,0 +1,31 @@
|
||||
#ifndef SCRATCHPAD_MANAGER_H
|
||||
#define SCRATCHPAD_MANAGER_H
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
|
||||
class DNNLScratchPadManager {
|
||||
public:
|
||||
static constexpr size_t allocation_unit = 4 * 1024; // 4KB
|
||||
|
||||
static DNNLScratchPadManager* get_dnnl_scratchpad_manager();
|
||||
|
||||
DNNLScratchPadManager();
|
||||
|
||||
template <typename T>
|
||||
T* get_data() {
|
||||
return reinterpret_cast<T*>(ptr_);
|
||||
}
|
||||
|
||||
static size_t round(size_t size) {
|
||||
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
|
||||
}
|
||||
|
||||
void realloc(size_t new_size);
|
||||
|
||||
private:
|
||||
size_t size_;
|
||||
void* ptr_;
|
||||
};
|
||||
|
||||
#endif
|
||||
@ -192,7 +192,7 @@ class SHMManager {
|
||||
const int group_size)
|
||||
: _rank(rank),
|
||||
_group_size(group_size),
|
||||
_thread_num(torch::get_num_threads()),
|
||||
_thread_num(omp_get_max_threads()),
|
||||
_shm_names({""}),
|
||||
_shared_mem_ptrs({nullptr}),
|
||||
_shm_ctx(nullptr) {
|
||||
|
||||
@ -74,25 +74,35 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype, bool is_vnni);
|
||||
|
||||
torch::Tensor get_scheduler_metadata(
|
||||
const int64_t num_req, const int64_t num_heads_q,
|
||||
const int64_t num_heads_kv, const int64_t head_dim,
|
||||
const torch::Tensor& seq_lens, at::ScalarType dtype,
|
||||
const torch::Tensor& query_start_loc, const bool casual,
|
||||
const int64_t window_size, const std::string& isa_hint,
|
||||
const bool enable_kv_split);
|
||||
|
||||
void cpu_attn_reshape_and_cache(const torch::Tensor& key,
|
||||
const torch::Tensor& value,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
const torch::Tensor& slot_mapping,
|
||||
const std::string& isa);
|
||||
|
||||
void cpu_attention_with_kv_cache(
|
||||
const torch::Tensor& query, const torch::Tensor& key_cache,
|
||||
const torch::Tensor& value_cache, torch::Tensor& output,
|
||||
const torch::Tensor& query_start_loc, const torch::Tensor& seq_lens,
|
||||
const double scale, const bool causal,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const int64_t sliding_window_left, const int64_t sliding_window_right,
|
||||
const torch::Tensor& block_table, const double softcap,
|
||||
const torch::Tensor& scheduler_metadata,
|
||||
const std::optional<torch::Tensor>& s_aux);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
// Attention ops
|
||||
// Compute the attention between an input query and the cached keys/values
|
||||
// using PagedAttention.
|
||||
ops.def(
|
||||
"paged_attention_v1("
|
||||
" Tensor! out, Tensor query, Tensor key_cache,"
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
|
||||
" int tp_rank, int blocksparse_local_blocks,"
|
||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
|
||||
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
|
||||
|
||||
ops.def(
|
||||
"dynamic_4bit_int_moe("
|
||||
"Tensor x, Tensor topk_ids, Tensor topk_weights,"
|
||||
@ -102,20 +112,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu);
|
||||
|
||||
// PagedAttention V2.
|
||||
ops.def(
|
||||
"paged_attention_v2("
|
||||
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
|
||||
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
|
||||
" int tp_rank, int blocksparse_local_blocks,"
|
||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
|
||||
|
||||
// Activation ops
|
||||
|
||||
// Activation function used in SwiGLU.
|
||||
@ -259,37 +255,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
|
||||
&int8_scaled_mm_with_quant);
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
// Cache ops
|
||||
// Swap in (out) the cache blocks from src to dst.
|
||||
cache_ops.def(
|
||||
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
|
||||
cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks);
|
||||
|
||||
// Copy the cache blocks from src to dst.
|
||||
cache_ops.def(
|
||||
"copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
|
||||
"Tensor block_mapping) -> ()");
|
||||
cache_ops.impl("copy_blocks", torch::kCPU, ©_blocks);
|
||||
|
||||
// Reshape the key and value tensors and cache them.
|
||||
cache_ops.def(
|
||||
"reshape_and_cache(Tensor key, Tensor value,"
|
||||
" Tensor! key_cache, Tensor! value_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
" Tensor k_scale, Tensor v_scale) -> ()");
|
||||
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
|
||||
|
||||
cache_ops.def(
|
||||
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
|
||||
" Tensor! kv_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
" Tensor scale) -> ()");
|
||||
cache_ops.impl("concat_and_cache_mla", torch::kCPU, &concat_and_cache_mla);
|
||||
// CPU attention kernels
|
||||
ops.def(
|
||||
"get_scheduler_metadata(int num_req, int num_heads_q, int num_heads_kv, "
|
||||
"int head_dim, Tensor seq_lens, ScalarType dtype, Tensor "
|
||||
"query_start_loc, bool casual, int window_size, str isa_hint, bool "
|
||||
"enable_kv_split) -> Tensor",
|
||||
&get_scheduler_metadata);
|
||||
ops.def(
|
||||
"cpu_attn_reshape_and_cache(Tensor key, Tensor value, Tensor(a2!) "
|
||||
"key_cache, Tensor(a3!) value_cache, Tensor slot_mapping, str "
|
||||
"isa) -> ()",
|
||||
&cpu_attn_reshape_and_cache);
|
||||
ops.def(
|
||||
"cpu_attention_with_kv_cache(Tensor query, Tensor key_cache, Tensor "
|
||||
"value_cache, Tensor(a3!) output, Tensor query_start_loc, Tensor "
|
||||
"seq_lens, float scale, bool causal, Tensor? alibi_slopes, SymInt "
|
||||
"sliding_window_left, SymInt sliding_window_right, Tensor block_table, "
|
||||
"float softcap, Tensor sheduler_metadata, Tensor? s_aux) -> ()",
|
||||
&cpu_attention_with_kv_cache);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
||||
|
||||
@ -45,6 +45,16 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
// Memory node binding
|
||||
if (numa_available() != -1) {
|
||||
int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front());
|
||||
// Verify all CPUs are on the same NUMA node
|
||||
for (size_t i = 1; i < omp_cpu_ids.size(); ++i) {
|
||||
int node_id = numa_node_of_cpu(omp_cpu_ids[i]);
|
||||
TORCH_CHECK(node_id == mem_node_id, "CPU ", omp_cpu_ids[i],
|
||||
" is on NUMA node ", node_id, ", but CPU ",
|
||||
omp_cpu_ids.front(), " is on NUMA node ", mem_node_id,
|
||||
". All CPUs should be on the same NUMA node for optimal "
|
||||
"performance. Memory will be bound to NUMA node ",
|
||||
mem_node_id, ".");
|
||||
}
|
||||
bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str());
|
||||
bitmask* src_mask = numa_get_membind();
|
||||
|
||||
|
||||
@ -3,14 +3,58 @@
|
||||
// need to be unsigned long long
|
||||
#include <iostream>
|
||||
|
||||
#include "cumem_allocator_compat.h"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
static const char* PYARGS_PARSE = "KKKK";
|
||||
#else
|
||||
#include <cstdlib>
|
||||
#include <cerrno>
|
||||
#include <climits>
|
||||
|
||||
// Default chunk size 256MB for ROCm. Can be overridden at runtime by the
|
||||
// environment variable VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE, specified in megabytes
|
||||
// (MB). The env value is parsed with strtoull as an integer number of MB
|
||||
// (decimal or 0x hex). The parsed MB value is converted to bytes. If
|
||||
// parsing fails, the value is 0, or the multiplication would overflow,
|
||||
// the default (256MB) is used.
|
||||
static const unsigned long long DEFAULT_MEMCREATE_CHUNK_SIZE =
|
||||
(256ULL * 1024ULL * 1024ULL);
|
||||
|
||||
static unsigned long long get_memcreate_chunk_size() {
|
||||
const char* env = getenv("VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE");
|
||||
if (!env) return DEFAULT_MEMCREATE_CHUNK_SIZE;
|
||||
char* endptr = nullptr;
|
||||
errno = 0;
|
||||
unsigned long long val_mb = strtoull(env, &endptr, 0);
|
||||
if (endptr == env || errno != 0) {
|
||||
// parsing failed, fallback to default
|
||||
return DEFAULT_MEMCREATE_CHUNK_SIZE;
|
||||
}
|
||||
if (val_mb == 0) return DEFAULT_MEMCREATE_CHUNK_SIZE;
|
||||
|
||||
const unsigned long long MB = 1024ULL * 1024ULL;
|
||||
// guard against overflow when converting MB -> bytes
|
||||
if (val_mb > (ULLONG_MAX / MB)) {
|
||||
return DEFAULT_MEMCREATE_CHUNK_SIZE;
|
||||
}
|
||||
return val_mb * MB;
|
||||
}
|
||||
|
||||
static inline unsigned long long my_min(unsigned long long a,
|
||||
unsigned long long b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
static const char* PYARGS_PARSE = "KKKO";
|
||||
#endif
|
||||
|
||||
extern "C" {
|
||||
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
#include <sys/types.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cuda.h>
|
||||
|
||||
char error_msg[10240]; // 10KB buffer to store error messages
|
||||
CUresult no_error = CUresult(0);
|
||||
@ -49,7 +93,12 @@ void ensure_context(unsigned long long device) {
|
||||
}
|
||||
|
||||
void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
|
||||
#ifndef USE_ROCM
|
||||
CUmemGenericAllocationHandle* p_memHandle) {
|
||||
#else
|
||||
CUmemGenericAllocationHandle** p_memHandle,
|
||||
unsigned long long* chunk_sizes, size_t num_chunks) {
|
||||
#endif
|
||||
ensure_context(device);
|
||||
// Define memory allocation properties
|
||||
CUmemAllocationProp prop = {};
|
||||
@ -58,6 +107,7 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
|
||||
prop.location.id = device;
|
||||
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Allocate memory using cuMemCreate
|
||||
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
|
||||
if (error_code != 0) {
|
||||
@ -67,6 +117,39 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
|
||||
if (error_code != 0) {
|
||||
return;
|
||||
}
|
||||
#else
|
||||
for (auto i = 0; i < num_chunks; ++i) {
|
||||
CUDA_CHECK(cuMemCreate(p_memHandle[i], chunk_sizes[i], &prop, 0));
|
||||
if (error_code != 0) {
|
||||
// Clean up previously created handles
|
||||
for (auto j = 0; j < i; ++j) {
|
||||
cuMemRelease(*(p_memHandle[j]));
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
unsigned long long allocated_size = 0;
|
||||
for (auto i = 0; i < num_chunks; ++i) {
|
||||
void* map_addr = (void*)((uintptr_t)d_mem + allocated_size);
|
||||
CUDA_CHECK(cuMemMap(map_addr, chunk_sizes[i], 0, *(p_memHandle[i]), 0));
|
||||
if (error_code != 0) {
|
||||
// unmap previously mapped chunks
|
||||
unsigned long long unmapped_size = 0;
|
||||
for (auto j = 0; j < i; ++j) {
|
||||
void* unmap_addr = (void*)((uintptr_t)d_mem + unmapped_size);
|
||||
cuMemUnmap(unmap_addr, chunk_sizes[j]);
|
||||
unmapped_size += chunk_sizes[j];
|
||||
}
|
||||
// release all created handles
|
||||
for (auto j = 0; j < num_chunks; ++j) {
|
||||
cuMemRelease(*(p_memHandle[j]));
|
||||
}
|
||||
return;
|
||||
}
|
||||
allocated_size += chunk_sizes[i];
|
||||
}
|
||||
#endif
|
||||
|
||||
CUmemAccessDesc accessDesc = {};
|
||||
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
accessDesc.location.id = device;
|
||||
@ -82,10 +165,16 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
|
||||
|
||||
void unmap_and_release(unsigned long long device, ssize_t size,
|
||||
CUdeviceptr d_mem,
|
||||
#ifndef USE_ROCM
|
||||
CUmemGenericAllocationHandle* p_memHandle) {
|
||||
#else
|
||||
CUmemGenericAllocationHandle** p_memHandle,
|
||||
unsigned long long* chunk_sizes, size_t num_chunks) {
|
||||
#endif
|
||||
// std::cout << "unmap_and_release: device=" << device << ", size=" << size <<
|
||||
// ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
|
||||
ensure_context(device);
|
||||
#ifndef USE_ROCM
|
||||
CUDA_CHECK(cuMemUnmap(d_mem, size));
|
||||
if (error_code != 0) {
|
||||
return;
|
||||
@ -94,6 +183,30 @@ void unmap_and_release(unsigned long long device, ssize_t size,
|
||||
if (error_code != 0) {
|
||||
return;
|
||||
}
|
||||
#else
|
||||
unsigned long long allocated_size = 0;
|
||||
CUresult first_error = no_error;
|
||||
|
||||
for (auto i = 0; i < num_chunks; ++i) {
|
||||
void* map_addr = (void*)((uintptr_t)d_mem + allocated_size);
|
||||
CUresult status = cuMemUnmap(map_addr, chunk_sizes[i]);
|
||||
if (status != no_error && first_error == no_error) {
|
||||
first_error = status;
|
||||
}
|
||||
allocated_size += chunk_sizes[i];
|
||||
}
|
||||
|
||||
for (auto i = 0; i < num_chunks; ++i) {
|
||||
CUresult status = cuMemRelease(*(p_memHandle[i]));
|
||||
if (status != no_error && first_error == no_error) {
|
||||
first_error = status;
|
||||
}
|
||||
}
|
||||
|
||||
if (first_error != no_error) {
|
||||
CUDA_CHECK(first_error);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
PyObject* create_tuple_from_c_integers(unsigned long long a,
|
||||
@ -120,6 +233,36 @@ PyObject* create_tuple_from_c_integers(unsigned long long a,
|
||||
return tuple; // Return the created tuple
|
||||
}
|
||||
|
||||
PyObject* create_tuple_from_c_mixed(unsigned long long a, unsigned long long b,
|
||||
unsigned long long c,
|
||||
CUmemGenericAllocationHandle** vec,
|
||||
unsigned long long* chunk_sizes,
|
||||
size_t num_chunks) {
|
||||
PyObject* tuple = PyTuple_New(4);
|
||||
if (!tuple) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// PyObject* list = PyList_New(vec.size());
|
||||
PyObject* list = PyList_New(num_chunks);
|
||||
for (auto i = 0; i < num_chunks; ++i) {
|
||||
PyObject* addr_size_pair = PyTuple_New(2);
|
||||
PyObject* addr = PyLong_FromUnsignedLongLong((unsigned long long)(vec[i]));
|
||||
PyObject* size =
|
||||
PyLong_FromUnsignedLongLong((unsigned long long)(chunk_sizes[i]));
|
||||
PyTuple_SetItem(addr_size_pair, 0, addr);
|
||||
PyTuple_SetItem(addr_size_pair, 1, size);
|
||||
PyList_SetItem(list, i, addr_size_pair);
|
||||
}
|
||||
|
||||
PyTuple_SetItem(tuple, 0, PyLong_FromUnsignedLongLong(a));
|
||||
PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b));
|
||||
PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c));
|
||||
PyTuple_SetItem(tuple, 3, list);
|
||||
|
||||
return tuple;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Our exported C functions that call Python:
|
||||
|
||||
@ -147,14 +290,55 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
|
||||
size_t alignedSize = ((size + granularity - 1) / granularity) * granularity;
|
||||
|
||||
CUdeviceptr d_mem;
|
||||
#ifndef USE_ROCM
|
||||
CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0));
|
||||
if (error_code != 0) {
|
||||
return nullptr;
|
||||
}
|
||||
#else
|
||||
CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, granularity, 0, 0));
|
||||
if (error_code != 0) {
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// allocate the CUmemGenericAllocationHandle
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)malloc(
|
||||
sizeof(CUmemGenericAllocationHandle));
|
||||
#else
|
||||
// Make sure chunk size is aligned with hardware granularity. The base
|
||||
// chunk size can be configured via environment variable
|
||||
// ``VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE``; otherwise
|
||||
// DEFAULT_MEMCREATE_CHUNK_SIZE is used.
|
||||
size_t base_chunk = (size_t)get_memcreate_chunk_size();
|
||||
size_t aligned_chunk_size =
|
||||
((base_chunk + granularity - 1) / granularity) * granularity;
|
||||
size_t num_chunks =
|
||||
(alignedSize + aligned_chunk_size - 1) / aligned_chunk_size;
|
||||
CUmemGenericAllocationHandle** p_memHandle =
|
||||
(CUmemGenericAllocationHandle**)malloc(
|
||||
num_chunks * sizeof(CUmemGenericAllocationHandle*));
|
||||
unsigned long long* chunk_sizes =
|
||||
(unsigned long long*)malloc(num_chunks * sizeof(unsigned long long));
|
||||
for (auto i = 0; i < num_chunks; ++i) {
|
||||
p_memHandle[i] = (CUmemGenericAllocationHandle*)malloc(
|
||||
sizeof(CUmemGenericAllocationHandle));
|
||||
if (p_memHandle[i] == nullptr) {
|
||||
std::cerr << "ERROR: malloc failed for p_memHandle[" << i << "].\n";
|
||||
for (auto j = 0; j < i; ++j) {
|
||||
free(p_memHandle[j]);
|
||||
}
|
||||
free(p_memHandle);
|
||||
free(chunk_sizes);
|
||||
return nullptr;
|
||||
}
|
||||
chunk_sizes[i] = (unsigned long long)my_min(
|
||||
(unsigned long long)(alignedSize - i * aligned_chunk_size),
|
||||
(unsigned long long)aligned_chunk_size);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (!g_python_malloc_callback) {
|
||||
std::cerr << "ERROR: g_python_malloc_callback not set.\n";
|
||||
@ -164,9 +348,15 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
|
||||
// Acquire GIL (not in stable ABI officially, but often works)
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
|
||||
#ifndef USE_ROCM
|
||||
PyObject* arg_tuple = create_tuple_from_c_integers(
|
||||
(unsigned long long)device, (unsigned long long)alignedSize,
|
||||
(unsigned long long)d_mem, (unsigned long long)p_memHandle);
|
||||
#else
|
||||
PyObject* arg_tuple = create_tuple_from_c_mixed(
|
||||
(unsigned long long)device, (unsigned long long)alignedSize,
|
||||
(unsigned long long)d_mem, p_memHandle, chunk_sizes, num_chunks);
|
||||
#endif
|
||||
|
||||
// Call g_python_malloc_callback
|
||||
PyObject* py_result =
|
||||
@ -182,7 +372,27 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
|
||||
PyGILState_Release(gstate);
|
||||
|
||||
// do the final mapping
|
||||
#ifndef USE_ROCM
|
||||
create_and_map(device, alignedSize, d_mem, p_memHandle);
|
||||
#else
|
||||
create_and_map(device, alignedSize, d_mem, p_memHandle, chunk_sizes,
|
||||
num_chunks);
|
||||
free(chunk_sizes);
|
||||
#endif
|
||||
|
||||
if (error_code != 0) {
|
||||
// free address and the handle
|
||||
CUDA_CHECK(cuMemAddressFree(d_mem, alignedSize));
|
||||
#ifndef USE_ROCM
|
||||
free(p_memHandle);
|
||||
#else
|
||||
for (size_t i = 0; i < num_chunks; ++i) {
|
||||
free(p_memHandle[i]);
|
||||
}
|
||||
free(p_memHandle);
|
||||
#endif
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return (void*)d_mem;
|
||||
}
|
||||
@ -206,36 +416,96 @@ void my_free(void* ptr, ssize_t size, int device, CUstream stream) {
|
||||
|
||||
if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
|
||||
Py_XDECREF(py_result);
|
||||
Py_XDECREF(py_ptr);
|
||||
return;
|
||||
}
|
||||
|
||||
unsigned long long recv_device, recv_size;
|
||||
unsigned long long recv_d_mem, recv_p_memHandle;
|
||||
unsigned long long recv_d_mem;
|
||||
#ifndef USE_ROCM
|
||||
unsigned long long recv_p_memHandle;
|
||||
#else
|
||||
PyObject* recv_p_memHandle;
|
||||
#endif
|
||||
// Unpack the tuple into four C integers
|
||||
if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size,
|
||||
if (!PyArg_ParseTuple(py_result, PYARGS_PARSE, &recv_device, &recv_size,
|
||||
&recv_d_mem, &recv_p_memHandle)) {
|
||||
// PyArg_ParseTuple sets an error if it fails
|
||||
Py_XDECREF(py_result);
|
||||
Py_XDECREF(py_ptr);
|
||||
return;
|
||||
}
|
||||
|
||||
// For ROCm, copy the Python list of (addr,size) pairs into C arrays while
|
||||
// holding the GIL. Then release the GIL and call the unmap/release helper
|
||||
// using the copied arrays. This avoids calling PyList_* APIs without the
|
||||
// GIL (which is undefined behavior and can crash when called from other
|
||||
// threads).
|
||||
CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem;
|
||||
#ifdef USE_ROCM
|
||||
Py_ssize_t num_chunks = PyList_Size(recv_p_memHandle);
|
||||
CUmemGenericAllocationHandle** p_memHandle =
|
||||
(CUmemGenericAllocationHandle**)malloc(
|
||||
num_chunks * sizeof(CUmemGenericAllocationHandle*));
|
||||
if (p_memHandle == nullptr) {
|
||||
Py_DECREF(py_ptr);
|
||||
Py_DECREF(py_result);
|
||||
PyGILState_Release(gstate);
|
||||
std::cerr << "ERROR: malloc failed for p_memHandle in my_free."
|
||||
<< std::endl;
|
||||
return;
|
||||
}
|
||||
unsigned long long* chunk_sizes =
|
||||
(unsigned long long*)malloc(num_chunks * sizeof(unsigned long long));
|
||||
if (chunk_sizes == nullptr) {
|
||||
free(p_memHandle);
|
||||
Py_DECREF(py_ptr);
|
||||
Py_DECREF(py_result);
|
||||
PyGILState_Release(gstate);
|
||||
std::cerr << "ERROR: malloc failed for chunk_sizes in my_free."
|
||||
<< std::endl;
|
||||
return;
|
||||
}
|
||||
for (Py_ssize_t i = 0; i < num_chunks; ++i) {
|
||||
PyObject* item = PyList_GetItem(recv_p_memHandle, i);
|
||||
PyObject* addr_py = PyTuple_GetItem(item, 0);
|
||||
PyObject* size_py = PyTuple_GetItem(item, 1);
|
||||
p_memHandle[i] =
|
||||
(CUmemGenericAllocationHandle*)PyLong_AsUnsignedLongLong(addr_py);
|
||||
chunk_sizes[i] = (unsigned long long)PyLong_AsUnsignedLongLong(size_py);
|
||||
}
|
||||
|
||||
// Drop temporary Python refs, then release the GIL before calling into
|
||||
// non-Python APIs.
|
||||
Py_DECREF(py_ptr);
|
||||
Py_DECREF(py_result);
|
||||
PyGILState_Release(gstate);
|
||||
|
||||
// recv_size == size
|
||||
// recv_device == device
|
||||
unmap_and_release(device, size, d_mem, p_memHandle, chunk_sizes, num_chunks);
|
||||
#else
|
||||
// Non-ROCm path: simple integer handle already extracted; drop temporary
|
||||
// Python refs while still holding the GIL, then release it.
|
||||
Py_DECREF(py_ptr);
|
||||
Py_DECREF(py_result);
|
||||
PyGILState_Release(gstate);
|
||||
|
||||
// Free memory
|
||||
|
||||
CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem;
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)recv_p_memHandle;
|
||||
unmap_and_release(device, size, d_mem, p_memHandle);
|
||||
#endif
|
||||
|
||||
// free address and the handle
|
||||
CUDA_CHECK(cuMemAddressFree(d_mem, size));
|
||||
if (error_code != 0) {
|
||||
return;
|
||||
#ifndef USE_ROCM
|
||||
free(p_memHandle);
|
||||
#else
|
||||
for (auto i = 0; i < num_chunks; ++i) {
|
||||
free(p_memHandle[i]);
|
||||
}
|
||||
free(p_memHandle);
|
||||
free(chunk_sizes);
|
||||
#endif
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@ -271,19 +541,87 @@ static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
|
||||
}
|
||||
|
||||
unsigned long long recv_device, recv_size;
|
||||
unsigned long long recv_d_mem, recv_p_memHandle;
|
||||
unsigned long long recv_d_mem;
|
||||
#ifndef USE_ROCM
|
||||
unsigned long long recv_p_memHandle;
|
||||
#else
|
||||
PyObject* recv_p_memHandle;
|
||||
#endif
|
||||
// Unpack the tuple into four C integers
|
||||
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
|
||||
&recv_p_memHandle)) {
|
||||
if (!PyArg_ParseTuple(args, PYARGS_PARSE, &recv_device, &recv_size,
|
||||
&recv_d_mem, &recv_p_memHandle)) {
|
||||
// PyArg_ParseTuple sets an error if it fails
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
|
||||
#ifndef USE_ROCM
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)recv_p_memHandle;
|
||||
|
||||
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle);
|
||||
#else
|
||||
if (!PyList_Check(recv_p_memHandle)) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"Expected a list for the 4th argument on ROCm");
|
||||
return nullptr;
|
||||
}
|
||||
Py_ssize_t num_chunks = PyList_Size(recv_p_memHandle);
|
||||
if (num_chunks < 0) {
|
||||
return nullptr; // PyList_Size sets an exception on error.
|
||||
}
|
||||
CUmemGenericAllocationHandle** p_memHandle =
|
||||
(CUmemGenericAllocationHandle**)malloc(
|
||||
num_chunks * sizeof(CUmemGenericAllocationHandle*));
|
||||
if (p_memHandle == nullptr) {
|
||||
PyErr_SetString(PyExc_MemoryError, "malloc failed for p_memHandle");
|
||||
return nullptr;
|
||||
}
|
||||
unsigned long long* chunk_sizes =
|
||||
(unsigned long long*)malloc(num_chunks * sizeof(unsigned long long));
|
||||
if (chunk_sizes == nullptr) {
|
||||
free(p_memHandle);
|
||||
PyErr_SetString(PyExc_MemoryError, "malloc failed for chunk_sizes");
|
||||
return nullptr;
|
||||
}
|
||||
for (Py_ssize_t i = 0; i < num_chunks; ++i) {
|
||||
PyObject* item = PyList_GetItem(recv_p_memHandle, i);
|
||||
if (item == nullptr || !PyTuple_Check(item) || PyTuple_Size(item) != 2) {
|
||||
free(p_memHandle);
|
||||
free(chunk_sizes);
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
"List items must be tuples of size 2 (handle_addr, size)");
|
||||
return nullptr;
|
||||
}
|
||||
PyObject* addr_py = PyTuple_GetItem(item, 0);
|
||||
PyObject* size_py = PyTuple_GetItem(item, 1);
|
||||
if (addr_py == nullptr || size_py == nullptr) {
|
||||
free(p_memHandle);
|
||||
free(chunk_sizes);
|
||||
return nullptr; // PyTuple_GetItem sets an exception
|
||||
}
|
||||
p_memHandle[i] =
|
||||
(CUmemGenericAllocationHandle*)PyLong_AsUnsignedLongLong(addr_py);
|
||||
if (PyErr_Occurred()) {
|
||||
free(p_memHandle);
|
||||
free(chunk_sizes);
|
||||
return nullptr;
|
||||
}
|
||||
chunk_sizes[i] = (unsigned long long)PyLong_AsUnsignedLongLong(size_py);
|
||||
if (PyErr_Occurred()) {
|
||||
free(p_memHandle);
|
||||
free(chunk_sizes);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle, chunk_sizes,
|
||||
num_chunks);
|
||||
|
||||
free(p_memHandle);
|
||||
free(chunk_sizes);
|
||||
#endif
|
||||
|
||||
if (error_code != 0) {
|
||||
error_code = no_error;
|
||||
@ -301,19 +639,56 @@ static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
|
||||
}
|
||||
|
||||
unsigned long long recv_device, recv_size;
|
||||
unsigned long long recv_d_mem, recv_p_memHandle;
|
||||
unsigned long long recv_d_mem;
|
||||
#ifndef USE_ROCM
|
||||
unsigned long long recv_p_memHandle;
|
||||
#else
|
||||
PyObject* recv_p_memHandle;
|
||||
#endif
|
||||
// Unpack the tuple into four C integers
|
||||
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
|
||||
&recv_p_memHandle)) {
|
||||
if (!PyArg_ParseTuple(args, PYARGS_PARSE, &recv_device, &recv_size,
|
||||
&recv_d_mem, &recv_p_memHandle)) {
|
||||
// PyArg_ParseTuple sets an error if it fails
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
|
||||
#ifndef USE_ROCM
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)recv_p_memHandle;
|
||||
|
||||
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle);
|
||||
#else
|
||||
Py_ssize_t num_chunks = PyList_Size(recv_p_memHandle);
|
||||
CUmemGenericAllocationHandle** p_memHandle =
|
||||
(CUmemGenericAllocationHandle**)malloc(
|
||||
num_chunks * sizeof(CUmemGenericAllocationHandle*));
|
||||
if (p_memHandle == nullptr) {
|
||||
PyErr_SetString(PyExc_MemoryError, "malloc failed for p_memHandle");
|
||||
return nullptr;
|
||||
}
|
||||
unsigned long long* chunk_sizes =
|
||||
(unsigned long long*)malloc(num_chunks * sizeof(unsigned long long));
|
||||
if (chunk_sizes == nullptr) {
|
||||
free(p_memHandle);
|
||||
PyErr_SetString(PyExc_MemoryError, "malloc failed for chunk_sizes");
|
||||
return nullptr;
|
||||
}
|
||||
for (auto i = 0; i < num_chunks; ++i) {
|
||||
PyObject* item = PyList_GetItem(recv_p_memHandle, i);
|
||||
PyObject* addr_py = PyTuple_GetItem(item, 0);
|
||||
PyObject* size_py = PyTuple_GetItem(item, 1);
|
||||
p_memHandle[i] =
|
||||
(CUmemGenericAllocationHandle*)PyLong_AsUnsignedLongLong(addr_py);
|
||||
chunk_sizes[i] = PyLong_AsUnsignedLongLong(size_py);
|
||||
}
|
||||
|
||||
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle, chunk_sizes,
|
||||
num_chunks);
|
||||
|
||||
free(p_memHandle);
|
||||
free(chunk_sizes);
|
||||
#endif
|
||||
|
||||
if (error_code != 0) {
|
||||
error_code = no_error;
|
||||
|
||||
109
csrc/cumem_allocator_compat.h
Normal file
109
csrc/cumem_allocator_compat.h
Normal file
@ -0,0 +1,109 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef USE_ROCM
|
||||
////////////////////////////////////////
|
||||
// For compatibility with CUDA and ROCm
|
||||
////////////////////////////////////////
|
||||
#include <hip/hip_runtime_api.h>
|
||||
|
||||
extern "C" {
|
||||
#ifndef CUDA_SUCCESS
|
||||
#define CUDA_SUCCESS hipSuccess
|
||||
#endif // CUDA_SUCCESS
|
||||
|
||||
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html
|
||||
typedef unsigned long long CUdevice;
|
||||
typedef hipDeviceptr_t CUdeviceptr;
|
||||
typedef hipError_t CUresult;
|
||||
typedef hipCtx_t CUcontext;
|
||||
typedef hipStream_t CUstream;
|
||||
typedef hipMemGenericAllocationHandle_t CUmemGenericAllocationHandle;
|
||||
typedef hipMemAllocationGranularity_flags CUmemAllocationGranularity_flags;
|
||||
typedef hipMemAllocationProp CUmemAllocationProp;
|
||||
typedef hipMemAccessDesc CUmemAccessDesc;
|
||||
|
||||
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
|
||||
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
|
||||
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
|
||||
#define CU_MEM_ALLOC_GRANULARITY_MINIMUM hipMemAllocationGranularityMinimum
|
||||
|
||||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
|
||||
#define CU_MEM_ALLOCATION_COMP_NONE 0x0
|
||||
|
||||
// Error Handling
|
||||
// https://docs.nvidia.com/cuda/archive/11.4.4/cuda-driver-api/group__CUDA__ERROR.html
|
||||
CUresult cuGetErrorString(CUresult hipError, const char** pStr) {
|
||||
*pStr = hipGetErrorString(hipError);
|
||||
return CUDA_SUCCESS;
|
||||
}
|
||||
|
||||
// Context Management
|
||||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html
|
||||
CUresult cuCtxGetCurrent(CUcontext* ctx) {
|
||||
// This API is deprecated on the AMD platform, only for equivalent cuCtx
|
||||
// driver API on the NVIDIA platform.
|
||||
return hipCtxGetCurrent(ctx);
|
||||
}
|
||||
|
||||
CUresult cuCtxSetCurrent(CUcontext ctx) {
|
||||
// This API is deprecated on the AMD platform, only for equivalent cuCtx
|
||||
// driver API on the NVIDIA platform.
|
||||
return hipCtxSetCurrent(ctx);
|
||||
}
|
||||
|
||||
// Primary Context Management
|
||||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PRIMARY__CTX.html
|
||||
CUresult cuDevicePrimaryCtxRetain(CUcontext* ctx, CUdevice dev) {
|
||||
return hipDevicePrimaryCtxRetain(ctx, dev);
|
||||
}
|
||||
|
||||
// Virtual Memory Management
|
||||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html
|
||||
CUresult cuMemAddressFree(CUdeviceptr ptr, size_t size) {
|
||||
return hipMemAddressFree(ptr, size);
|
||||
}
|
||||
|
||||
CUresult cuMemAddressReserve(CUdeviceptr* ptr, size_t size, size_t alignment,
|
||||
CUdeviceptr addr, unsigned long long flags) {
|
||||
return hipMemAddressReserve(ptr, size, alignment, addr, flags);
|
||||
}
|
||||
|
||||
CUresult cuMemCreate(CUmemGenericAllocationHandle* handle, size_t size,
|
||||
const CUmemAllocationProp* prop,
|
||||
unsigned long long flags) {
|
||||
return hipMemCreate(handle, size, prop, flags);
|
||||
}
|
||||
|
||||
CUresult cuMemGetAllocationGranularity(
|
||||
size_t* granularity, const CUmemAllocationProp* prop,
|
||||
CUmemAllocationGranularity_flags option) {
|
||||
return hipMemGetAllocationGranularity(granularity, prop, option);
|
||||
}
|
||||
|
||||
CUresult cuMemMap(CUdeviceptr dptr, size_t size, size_t offset,
|
||||
CUmemGenericAllocationHandle handle,
|
||||
unsigned long long flags) {
|
||||
return hipMemMap(dptr, size, offset, handle, flags);
|
||||
}
|
||||
|
||||
CUresult cuMemRelease(CUmemGenericAllocationHandle handle) {
|
||||
return hipMemRelease(handle);
|
||||
}
|
||||
|
||||
CUresult cuMemSetAccess(CUdeviceptr ptr, size_t size,
|
||||
const CUmemAccessDesc* desc, size_t count) {
|
||||
return hipMemSetAccess(ptr, size, desc, count);
|
||||
}
|
||||
|
||||
CUresult cuMemUnmap(CUdeviceptr ptr, size_t size) {
|
||||
return hipMemUnmap(ptr, size);
|
||||
}
|
||||
} // extern "C"
|
||||
|
||||
#else
|
||||
////////////////////////////////////////
|
||||
// Import CUDA headers for NVIDIA GPUs
|
||||
////////////////////////////////////////
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
@ -88,3 +88,32 @@
|
||||
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_VEC_SIZE(VEC_SIZE, ...) \
|
||||
switch (VEC_SIZE) { \
|
||||
case 16: { \
|
||||
constexpr int vec_size = 16; \
|
||||
__VA_ARGS__(); \
|
||||
break; \
|
||||
} \
|
||||
case 8: { \
|
||||
constexpr int vec_size = 8; \
|
||||
__VA_ARGS__(); \
|
||||
break; \
|
||||
} \
|
||||
case 4: { \
|
||||
constexpr int vec_size = 4; \
|
||||
__VA_ARGS__(); \
|
||||
break; \
|
||||
} \
|
||||
case 2: { \
|
||||
constexpr int vec_size = 2; \
|
||||
__VA_ARGS__(); \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
constexpr int vec_size = 1; \
|
||||
__VA_ARGS__(); \
|
||||
break; \
|
||||
} \
|
||||
}
|
||||
|
||||
418
csrc/fused_qknorm_rope_kernel.cu
Normal file
418
csrc/fused_qknorm_rope_kernel.cu
Normal file
@ -0,0 +1,418 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <cmath>
|
||||
#include <cuda_runtime.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include <torch/cuda.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#include "type_convert.cuh"
|
||||
|
||||
#define CHECK_TYPE(x, st) \
|
||||
TORCH_CHECK(x.scalar_type() == st, #x " dtype is ", x.scalar_type(), \
|
||||
", while ", st, " is expected")
|
||||
#define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_TH_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#define FINAL_MASK 0xffffffffffffffffULL
|
||||
#else
|
||||
#define FINAL_MASK 0xffffffff
|
||||
#endif
|
||||
|
||||
namespace tensorrt_llm::common {
|
||||
template <typename T, int num>
|
||||
struct packed_as;
|
||||
// Specialization for packed_as used in this kernel.
|
||||
template <>
|
||||
struct packed_as<uint, 1> {
|
||||
using type = uint;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct packed_as<uint, 2> {
|
||||
using type = uint2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct packed_as<uint, 4> {
|
||||
using type = uint4;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__inline__ __device__ T warpReduceSum(T val) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1)
|
||||
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ __host__ T divUp(T m, T n) {
|
||||
return (m + n - 1) / n;
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
|
||||
namespace tensorrt_llm::kernels {
|
||||
// NOTE(zhuhaoran): This kernel is adapted from TensorRT-LLM implementation,
|
||||
// with added support for passing the cos_sin_cache as an input.
|
||||
// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
|
||||
|
||||
// Perform per-head QK Norm and RoPE in a single kernel.
|
||||
// scalar_t_in: data type of QKV and RMSNorm weights
|
||||
// scalar_t_cache: data type of cos/sin cache
|
||||
// head_dim: the dimension of each head
|
||||
// interleave: interleave=!is_neox.
|
||||
template <typename scalar_t_in, typename scalar_t_cache, int head_dim,
|
||||
bool interleave>
|
||||
__global__ void fusedQKNormRopeKernel(
|
||||
void* qkv_void, // Combined QKV tensor
|
||||
int const num_heads_q, // Number of query heads
|
||||
int const num_heads_k, // Number of key heads
|
||||
int const num_heads_v, // Number of value heads
|
||||
float const eps, // Epsilon for RMS normalization
|
||||
void const* q_weight_void, // RMSNorm weights for query
|
||||
void const* k_weight_void, // RMSNorm weights for key
|
||||
void const* cos_sin_cache_void, // Pre-computed cos/sin cache
|
||||
int64_t const* position_ids, // Position IDs for RoPE
|
||||
int const num_tokens // Number of tokens
|
||||
) {
|
||||
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
|
||||
if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
|
||||
std::is_same_v<scalar_t_cache, c10::BFloat16>) {
|
||||
return;
|
||||
} else {
|
||||
#endif
|
||||
|
||||
using Converter = vllm::_typeConvert<scalar_t_in>;
|
||||
static_assert(Converter::exists,
|
||||
"Input QKV data type is not supported for this CUDA "
|
||||
"architecture or toolkit version.");
|
||||
using T_in = typename Converter::hip_type;
|
||||
using T2_in = typename Converter::packed_hip_type;
|
||||
|
||||
using CacheConverter = vllm::_typeConvert<scalar_t_cache>;
|
||||
static_assert(CacheConverter::exists,
|
||||
"Cache data type is not supported for this CUDA architecture "
|
||||
"or toolkit version.");
|
||||
using T_cache = typename CacheConverter::hip_type;
|
||||
|
||||
T_in* qkv = reinterpret_cast<T_in*>(qkv_void);
|
||||
T_in const* q_weight = reinterpret_cast<T_in const*>(q_weight_void);
|
||||
T_in const* k_weight = reinterpret_cast<T_in const*>(k_weight_void);
|
||||
T_cache const* cos_sin_cache =
|
||||
reinterpret_cast<T_cache const*>(cos_sin_cache_void);
|
||||
|
||||
int const warpsPerBlock = blockDim.x / 32;
|
||||
int const warpId = threadIdx.x / 32;
|
||||
int const laneId = threadIdx.x % 32;
|
||||
|
||||
// Calculate global warp index to determine which head/token this warp
|
||||
// processes
|
||||
int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;
|
||||
|
||||
// Total number of attention heads (Q and K)
|
||||
int const total_qk_heads = num_heads_q + num_heads_k;
|
||||
|
||||
// Determine which token and head type (Q or K) this warp processes
|
||||
int const tokenIdx = globalWarpIdx / total_qk_heads;
|
||||
int const localHeadIdx = globalWarpIdx % total_qk_heads;
|
||||
|
||||
// Skip if this warp is assigned beyond the number of tokens
|
||||
if (tokenIdx >= num_tokens) return;
|
||||
|
||||
bool const isQ = localHeadIdx < num_heads_q;
|
||||
int const headIdx = isQ ? localHeadIdx : localHeadIdx - num_heads_q;
|
||||
|
||||
int const num_heads = num_heads_q + num_heads_k + num_heads_v;
|
||||
|
||||
static_assert(head_dim % (32 * 2) == 0,
|
||||
"head_dim must be divisible by 64 (each warp processes one "
|
||||
"head, and each thread gets even number of "
|
||||
"elements)");
|
||||
constexpr int numElemsPerThread = head_dim / 32;
|
||||
float elements[numElemsPerThread];
|
||||
constexpr int elemSizeBytes = numElemsPerThread * sizeof(__nv_bfloat16);
|
||||
static_assert(elemSizeBytes % 4 == 0,
|
||||
"numSizeBytes must be a multiple of 4");
|
||||
constexpr int vecSize =
|
||||
elemSizeBytes /
|
||||
4; // Use packed_as<uint, vecSize> to perform loading/saving.
|
||||
using vec_T = typename tensorrt_llm::common::packed_as<uint, vecSize>::type;
|
||||
|
||||
int offsetWarp; // Offset for the warp
|
||||
if (isQ) {
|
||||
// Q segment: token offset + head offset within Q segment
|
||||
offsetWarp = tokenIdx * num_heads * head_dim + headIdx * head_dim;
|
||||
} else {
|
||||
// K segment: token offset + entire Q segment + head offset within K
|
||||
// segment
|
||||
offsetWarp = tokenIdx * num_heads * head_dim + num_heads_q * head_dim +
|
||||
headIdx * head_dim;
|
||||
}
|
||||
int offsetThread = offsetWarp + laneId * numElemsPerThread;
|
||||
|
||||
// Sum of squares for RMSNorm
|
||||
float sumOfSquares = 0.0f;
|
||||
|
||||
// Load.
|
||||
{
|
||||
vec_T vec = *reinterpret_cast<vec_T const*>(&qkv[offsetThread]);
|
||||
constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_packed_elems; i++) {
|
||||
// Interpret the generic vector chunk as the specific packed type
|
||||
T2_in packed_val = *(reinterpret_cast<T2_in*>(&vec) + i);
|
||||
// Convert to float2 for computation
|
||||
float2 vals = Converter::convert(packed_val);
|
||||
sumOfSquares += vals.x * vals.x;
|
||||
sumOfSquares += vals.y * vals.y;
|
||||
|
||||
elements[2 * i] = vals.x;
|
||||
elements[2 * i + 1] = vals.y;
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce sum across warp using the utility function
|
||||
sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);
|
||||
|
||||
// Compute RMS normalization factor
|
||||
float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);
|
||||
|
||||
// Normalize elements
|
||||
#pragma unroll
|
||||
for (int i = 0; i < numElemsPerThread; i++) {
|
||||
int dim = laneId * numElemsPerThread + i;
|
||||
float weight = isQ ? Converter::convert(q_weight[dim])
|
||||
: Converter::convert(k_weight[dim]);
|
||||
elements[i] *= rms_rcp * weight;
|
||||
}
|
||||
|
||||
// Apply RoPE to normalized elements
|
||||
float elements2[numElemsPerThread]; // Additional buffer required for RoPE.
|
||||
|
||||
int64_t pos_id = position_ids[tokenIdx];
|
||||
|
||||
// Calculate cache pointer for this position - similar to
|
||||
// pos_encoding_kernels.cu
|
||||
T_cache const* cache_ptr = cos_sin_cache + pos_id * head_dim;
|
||||
int const embed_dim = head_dim / 2;
|
||||
T_cache const* cos_ptr = cache_ptr;
|
||||
T_cache const* sin_ptr = cache_ptr + embed_dim;
|
||||
|
||||
if constexpr (interleave) {
|
||||
// Perform interleaving. Use pre-computed cos/sin values.
|
||||
#pragma unroll
|
||||
for (int i = 0; i < numElemsPerThread / 2; ++i) {
|
||||
int const idx0 = 2 * i;
|
||||
int const idx1 = 2 * i + 1;
|
||||
|
||||
float const val0 = elements[idx0];
|
||||
float const val1 = elements[idx1];
|
||||
|
||||
int const dim_idx = laneId * numElemsPerThread + idx0;
|
||||
int const half_dim = dim_idx / 2;
|
||||
float const cos_val =
|
||||
CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
|
||||
float const sin_val =
|
||||
CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
|
||||
|
||||
elements[idx0] = val0 * cos_val - val1 * sin_val;
|
||||
elements[idx1] = val0 * sin_val + val1 * cos_val;
|
||||
}
|
||||
} else {
|
||||
// Before data exchange with in warp, we need to sync.
|
||||
__syncwarp();
|
||||
// Get the data from the other half of the warp. Use pre-computed cos/sin
|
||||
// values.
|
||||
#pragma unroll
|
||||
for (int i = 0; i < numElemsPerThread; i++) {
|
||||
elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
|
||||
if (laneId < 16) {
|
||||
elements2[i] = -elements2[i];
|
||||
}
|
||||
|
||||
int dim_idx = laneId * numElemsPerThread + i;
|
||||
dim_idx = (dim_idx * 2) % head_dim;
|
||||
int half_dim = dim_idx / 2;
|
||||
// Use pre-computed cos/sin from cache
|
||||
float cos_val = CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
|
||||
float sin_val = CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
|
||||
|
||||
elements[i] = elements[i] * cos_val + elements2[i] * sin_val;
|
||||
}
|
||||
// __shfl_xor_sync does not provide memfence. Need to sync again.
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Store.
|
||||
{
|
||||
vec_T vec;
|
||||
constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_packed_elems; i++) {
|
||||
// Convert from float2 back to the specific packed type
|
||||
T2_in packed_val = Converter::convert(
|
||||
make_float2(elements[2 * i], elements[2 * i + 1]));
|
||||
// Place it into the generic vector
|
||||
*(reinterpret_cast<T2_in*>(&vec) + i) = packed_val;
|
||||
}
|
||||
*reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
|
||||
}
|
||||
|
||||
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Borrowed from
|
||||
// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
|
||||
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
|
||||
if (interleave) { \
|
||||
const bool INTERLEAVE = true; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
const bool INTERLEAVE = false; \
|
||||
__VA_ARGS__ \
|
||||
}
|
||||
|
||||
template <typename scalar_t_in, typename scalar_t_cache>
|
||||
void launchFusedQKNormRope(void* qkv, int const num_tokens,
|
||||
int const num_heads_q, int const num_heads_k,
|
||||
int const num_heads_v, int const head_dim,
|
||||
float const eps, void const* q_weight,
|
||||
void const* k_weight, void const* cos_sin_cache,
|
||||
bool const interleave, int64_t const* position_ids,
|
||||
cudaStream_t stream) {
|
||||
constexpr int blockSize = 256;
|
||||
|
||||
int const warpsPerBlock = blockSize / 32;
|
||||
int const totalQKHeads = num_heads_q + num_heads_k;
|
||||
int const totalWarps = num_tokens * totalQKHeads;
|
||||
|
||||
int const gridSize = common::divUp(totalWarps, warpsPerBlock);
|
||||
dim3 gridDim(gridSize);
|
||||
dim3 blockDim(blockSize);
|
||||
|
||||
switch (head_dim) {
|
||||
case 64:
|
||||
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
|
||||
fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 64, INTERLEAVE>
|
||||
<<<gridDim, blockDim, 0, stream>>>(
|
||||
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
|
||||
k_weight, cos_sin_cache, position_ids, num_tokens);
|
||||
});
|
||||
break;
|
||||
case 128:
|
||||
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
|
||||
fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 128, INTERLEAVE>
|
||||
<<<gridDim, blockDim, 0, stream>>>(
|
||||
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
|
||||
k_weight, cos_sin_cache, position_ids, num_tokens);
|
||||
});
|
||||
break;
|
||||
case 256:
|
||||
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
|
||||
fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 256, INTERLEAVE>
|
||||
<<<gridDim, blockDim, 0, stream>>>(
|
||||
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
|
||||
k_weight, cos_sin_cache, position_ids, num_tokens);
|
||||
});
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false,
|
||||
"Unsupported head dimension for fusedQKNormRope: ", head_dim);
|
||||
}
|
||||
}
|
||||
} // namespace tensorrt_llm::kernels
|
||||
|
||||
void fused_qk_norm_rope(
|
||||
torch::Tensor& qkv, // Combined QKV tensor [num_tokens,
|
||||
// (num_heads_q+num_heads_k+num_heads_v)*head_dim]
|
||||
int64_t num_heads_q, // Number of query heads
|
||||
int64_t num_heads_k, // Number of key heads
|
||||
int64_t num_heads_v, // Number of value heads
|
||||
int64_t head_dim, // Dimension per head
|
||||
double eps, // Epsilon for RMS normalization
|
||||
torch::Tensor& q_weight, // RMSNorm weights for query [head_dim]
|
||||
torch::Tensor& k_weight, // RMSNorm weights for key [head_dim]
|
||||
torch::Tensor& cos_sin_cache, // Cos/sin cache [max_position, head_dim]
|
||||
bool is_neox, // Whether RoPE is applied in Neox style
|
||||
torch::Tensor& position_ids // Position IDs for RoPE [num_tokens]
|
||||
) {
|
||||
// Input validation
|
||||
CHECK_INPUT(qkv);
|
||||
CHECK_INPUT(position_ids);
|
||||
CHECK_INPUT(q_weight);
|
||||
CHECK_INPUT(k_weight);
|
||||
CHECK_INPUT(cos_sin_cache);
|
||||
CHECK_TYPE(position_ids, torch::kInt64);
|
||||
|
||||
TORCH_CHECK(qkv.dim() == 2,
|
||||
"QKV tensor must be 2D: [num_tokens, "
|
||||
"(num_heads_q+num_heads_k+num_heads_v)*head_dim]");
|
||||
TORCH_CHECK(position_ids.dim() == 1, "Position IDs must be 1D: [num_tokens]");
|
||||
TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]");
|
||||
TORCH_CHECK(k_weight.dim() == 1, "Key weights must be 1D: [head_dim]");
|
||||
TORCH_CHECK(cos_sin_cache.dim() == 2,
|
||||
"Cos/sin cache must be 2D: [max_position, head_dim]");
|
||||
TORCH_CHECK(q_weight.size(0) == head_dim,
|
||||
"Query weights size must match head dimension");
|
||||
TORCH_CHECK(k_weight.size(0) == head_dim,
|
||||
"Key weights size must match head dimension");
|
||||
TORCH_CHECK(cos_sin_cache.size(1) == head_dim,
|
||||
"Cos/sin cache dimension must match head_dim");
|
||||
TORCH_CHECK(qkv.scalar_type() == q_weight.scalar_type() &&
|
||||
qkv.scalar_type() == k_weight.scalar_type(),
|
||||
"qkv, q_weight and k_weight must have the same dtype");
|
||||
|
||||
int64_t num_tokens = qkv.size(0);
|
||||
TORCH_CHECK(position_ids.size(0) == num_tokens,
|
||||
"Number of tokens in position_ids must match QKV");
|
||||
|
||||
int64_t total_heads = num_heads_q + num_heads_k + num_heads_v;
|
||||
TORCH_CHECK(
|
||||
qkv.size(1) == total_heads * head_dim,
|
||||
"QKV tensor size must match total number of heads and head dimension");
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(qkv.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
|
||||
using qkv_scalar_t = scalar_t;
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
cos_sin_cache.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
|
||||
using cache_scalar_t = scalar_t;
|
||||
tensorrt_llm::kernels::launchFusedQKNormRope<qkv_scalar_t,
|
||||
cache_scalar_t>(
|
||||
qkv.data_ptr(), static_cast<int>(num_tokens),
|
||||
static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
|
||||
static_cast<int>(num_heads_v), static_cast<int>(head_dim),
|
||||
static_cast<float>(eps), q_weight.data_ptr(), k_weight.data_ptr(),
|
||||
cos_sin_cache.data_ptr(), !is_neox,
|
||||
reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
|
||||
stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -10,7 +10,7 @@
|
||||
namespace vllm {
|
||||
|
||||
// TODO(woosuk): Further optimize this kernel.
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, int VEC_SIZE>
|
||||
__global__ void rms_norm_kernel(
|
||||
scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
@ -21,7 +21,6 @@ __global__ void rms_norm_kernel(
|
||||
float variance = 0.0f;
|
||||
const scalar_t* input_row = input + blockIdx.x * input_stride;
|
||||
|
||||
constexpr int VEC_SIZE = 8;
|
||||
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VEC_SIZE; ++i) {
|
||||
@ -45,10 +44,20 @@ __global__ void rms_norm_kernel(
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)input[blockIdx.x * input_stride + idx];
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
((scalar_t)(x * s_variance)) * weight[idx];
|
||||
scalar_t* out_row = out + blockIdx.x * hidden_size;
|
||||
auto* v_in = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(input_row);
|
||||
auto* v_w = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(weight);
|
||||
auto* v_out = reinterpret_cast<vec_n_t<scalar_t, VEC_SIZE>*>(out_row);
|
||||
for (int i = threadIdx.x; i < hidden_size / VEC_SIZE; i += blockDim.x) {
|
||||
vec_n_t<scalar_t, VEC_SIZE> dst;
|
||||
vec_n_t<scalar_t, VEC_SIZE> src1 = v_in[i];
|
||||
vec_n_t<scalar_t, VEC_SIZE> src2 = v_w[i];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
float x = static_cast<float>(src1.val[j]);
|
||||
dst.val[j] = ((scalar_t)(x * s_variance)) * src2.val[j];
|
||||
}
|
||||
v_out[i] = dst;
|
||||
}
|
||||
}
|
||||
|
||||
@ -168,16 +177,24 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
int num_tokens = input_view.numel() / hidden_size;
|
||||
int64_t input_stride = input_view.stride(-2);
|
||||
|
||||
// For large num_tokens, use smaller blocks to increase SM concurrency.
|
||||
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input_view.scalar_type(), "rms_norm_kernel", [&] {
|
||||
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(),
|
||||
input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
|
||||
hidden_size);
|
||||
const int calculated_vec_size =
|
||||
std::gcd(16 / sizeof(scalar_t), hidden_size);
|
||||
const int block_size =
|
||||
std::min(hidden_size / calculated_vec_size, max_block_size);
|
||||
dim3 block(block_size);
|
||||
VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
|
||||
vllm::rms_norm_kernel<scalar_t, vec_size><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(),
|
||||
input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
|
||||
hidden_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@
|
||||
namespace vllm {
|
||||
|
||||
// TODO(woosuk): Further optimize this kernel.
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
template <typename scalar_t, typename fp8_type, int VEC_SIZE>
|
||||
__global__ void rms_norm_static_fp8_quant_kernel(
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
@ -31,7 +31,6 @@ __global__ void rms_norm_static_fp8_quant_kernel(
|
||||
|
||||
const scalar_t* input_row = input + blockIdx.x * input_stride;
|
||||
|
||||
constexpr int VEC_SIZE = 8;
|
||||
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VEC_SIZE; ++i) {
|
||||
@ -58,11 +57,18 @@ __global__ void rms_norm_static_fp8_quant_kernel(
|
||||
// invert scale to avoid division
|
||||
float const scale_inv = 1.0f / *scale;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)input[blockIdx.x * input_stride + idx];
|
||||
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
|
||||
auto* v_in = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(input_row);
|
||||
auto* v_w = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(weight);
|
||||
for (int idx = threadIdx.x; idx < hidden_size / VEC_SIZE; idx += blockDim.x) {
|
||||
vec_n_t<scalar_t, VEC_SIZE> src1 = v_in[idx];
|
||||
vec_n_t<scalar_t, VEC_SIZE> src2 = v_w[idx];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
float x = static_cast<float>(src1.val[j]);
|
||||
float const out_norm = ((scalar_t)(x * s_variance)) * src2.val[j];
|
||||
out[blockIdx.x * hidden_size + idx * VEC_SIZE + j] =
|
||||
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -188,20 +194,29 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
int input_stride = input.stride(-2);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
// For large num_tokens, use smaller blocks to increase SM concurrency.
|
||||
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
|
||||
vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
input_stride, weight.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), epsilon, num_tokens,
|
||||
hidden_size);
|
||||
const int calculated_vec_size =
|
||||
std::gcd(16 / sizeof(scalar_t), hidden_size);
|
||||
const int block_size =
|
||||
std::min(hidden_size / calculated_vec_size, max_block_size);
|
||||
dim3 block(block_size);
|
||||
VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
|
||||
vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t,
|
||||
vec_size>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
input_stride, weight.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), epsilon, num_tokens,
|
||||
hidden_size);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@ -92,6 +92,12 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
||||
torch::Tensor& weight, double epsilon);
|
||||
|
||||
void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q,
|
||||
int64_t num_heads_k, int64_t num_heads_v,
|
||||
int64_t head_dim, double eps, torch::Tensor& q_weight,
|
||||
torch::Tensor& k_weight, torch::Tensor& cos_sin_cache,
|
||||
bool is_neox, torch::Tensor& position_ids);
|
||||
|
||||
void apply_repetition_penalties_(torch::Tensor& logits,
|
||||
const torch::Tensor& prompt_mask,
|
||||
const torch::Tensor& output_mask,
|
||||
|
||||
@ -48,7 +48,8 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
using ElementBlockScale = float;
|
||||
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<
|
||||
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
|
||||
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
|
||||
cute::GMMA::Major::MN, cute::GMMA::Major::K>;
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
|
||||
@ -175,6 +175,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"float epsilon) -> ()");
|
||||
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
|
||||
|
||||
// Function for fused QK Norm and RoPE
|
||||
ops.def(
|
||||
"fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
|
||||
"int num_heads_k, int num_heads_v, int head_dim, float eps, "
|
||||
"Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
|
||||
"bool is_neox, Tensor position_ids) -> ()");
|
||||
ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
|
||||
|
||||
// Apply repetition penalties to logits in-place
|
||||
ops.def(
|
||||
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
|
||||
|
||||
@ -29,6 +29,22 @@ struct _typeConvert {
|
||||
static constexpr bool exists = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct _typeConvert<float> {
|
||||
static constexpr bool exists = true;
|
||||
using hip_type = float;
|
||||
using packed_hip_type = float2;
|
||||
using packed_hip_type4 = float4; // For 128-bit vectorization
|
||||
|
||||
__device__ static __forceinline__ float convert(hip_type x) { return x; }
|
||||
__device__ static __forceinline__ float2 convert(packed_hip_type x) {
|
||||
return x;
|
||||
}
|
||||
__device__ static __forceinline__ float4 convert(packed_hip_type4 x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
|
||||
// CUDA < 12.0 runs into issues with packed type conversion
|
||||
template <>
|
||||
@ -37,41 +53,44 @@ struct _typeConvert<c10::Half> {
|
||||
using hip_type = __half;
|
||||
using packed_hip_type = __half2;
|
||||
|
||||
__device__ static inline float convert(hip_type x) { return __half2float(x); }
|
||||
__device__ static inline float2 convert(packed_hip_type x) {
|
||||
__device__ static __forceinline__ float convert(hip_type x) {
|
||||
return __half2float(x);
|
||||
}
|
||||
__device__ static __forceinline__ float2 convert(packed_hip_type x) {
|
||||
return __half22float2(x);
|
||||
}
|
||||
__device__ static inline hip_type convert(float x) {
|
||||
__device__ static __forceinline__ hip_type convert(float x) {
|
||||
return __float2half_rn(x);
|
||||
}
|
||||
__device__ static inline packed_hip_type convert(float2 x) {
|
||||
__device__ static __forceinline__ packed_hip_type convert(float2 x) {
|
||||
return __float22half2_rn(x);
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) || defined(USE_ROCM)
|
||||
// CUDA_ARCH < 800 does not have BF16 support
|
||||
// TODO: Add in ROCm support once public headers handle bf16 maturely
|
||||
// ROCm 7.0+ supports bfloat16
|
||||
template <>
|
||||
struct _typeConvert<c10::BFloat16> {
|
||||
static constexpr bool exists = true;
|
||||
using hip_type = __nv_bfloat16;
|
||||
using packed_hip_type = __nv_bfloat162;
|
||||
|
||||
__device__ static inline float convert(hip_type x) {
|
||||
__device__ static __forceinline__ float convert(hip_type x) {
|
||||
return __bfloat162float(x);
|
||||
}
|
||||
__device__ static inline float2 convert(packed_hip_type x) {
|
||||
__device__ static __forceinline__ float2 convert(packed_hip_type x) {
|
||||
return __bfloat1622float2(x);
|
||||
}
|
||||
__device__ static inline hip_type convert(float x) {
|
||||
__device__ static __forceinline__ hip_type convert(float x) {
|
||||
return __float2bfloat16(x);
|
||||
}
|
||||
__device__ static inline packed_hip_type convert(float2 x) {
|
||||
__device__ static __forceinline__ packed_hip_type convert(float2 x) {
|
||||
return __float22bfloat162_rn(x);
|
||||
}
|
||||
};
|
||||
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) ||
|
||||
// defined(USE_ROCM)
|
||||
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
|
||||
// 12000))
|
||||
|
||||
@ -95,10 +114,15 @@ struct alignas(16) _f16Vec {
|
||||
if constexpr (width % 2 == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; i += 2) {
|
||||
T2 temp{data[i], data[i + 1]};
|
||||
temp += T2{other.data[i], other.data[i + 1]};
|
||||
data[i] = temp.x;
|
||||
data[i + 1] = temp.y;
|
||||
if constexpr (std::is_same_v<T2, float2>) {
|
||||
data[i] += other.data[i];
|
||||
data[i + 1] += other.data[i + 1];
|
||||
} else {
|
||||
T2 temp{data[i], data[i + 1]};
|
||||
temp += T2{other.data[i], other.data[i + 1]};
|
||||
data[i] = temp.x;
|
||||
data[i + 1] = temp.y;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
@ -111,10 +135,15 @@ struct alignas(16) _f16Vec {
|
||||
if constexpr (width % 2 == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; i += 2) {
|
||||
T2 temp{data[i], data[i + 1]};
|
||||
temp *= T2{other.data[i], other.data[i + 1]};
|
||||
data[i] = temp.x;
|
||||
data[i + 1] = temp.y;
|
||||
if constexpr (std::is_same_v<T2, float2>) {
|
||||
data[i] *= other.data[i];
|
||||
data[i + 1] *= other.data[i + 1];
|
||||
} else {
|
||||
T2 temp{data[i], data[i + 1]};
|
||||
temp *= T2{other.data[i], other.data[i + 1]};
|
||||
data[i] = temp.x;
|
||||
data[i + 1] = temp.y;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
# VLLM_CPU_DISABLE_AVX512=false (default)|true
|
||||
# VLLM_CPU_AVX512BF16=false (default)|true
|
||||
# VLLM_CPU_AVX512VNNI=false (default)|true
|
||||
# VLLM_CPU_AMXBF16=false (default)|true
|
||||
#
|
||||
|
||||
######################### COMMON BASE IMAGE #########################
|
||||
@ -92,6 +93,9 @@ ENV VLLM_CPU_AVX512BF16=${VLLM_CPU_AVX512BF16}
|
||||
# Support for building with AVX512VNNI ISA: docker build --build-arg VLLM_CPU_AVX512VNNI="true" ...
|
||||
ARG VLLM_CPU_AVX512VNNI=0
|
||||
ENV VLLM_CPU_AVX512VNNI=${VLLM_CPU_AVX512VNNI}
|
||||
# Support for building with AMXBF16 ISA: docker build --build-arg VLLM_CPU_AMXBF16="true" ...
|
||||
ARG VLLM_CPU_AMXBF16=0
|
||||
ENV VLLM_CPU_AMXBF16=${VLLM_CPU_AMXBF16}
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
|
||||
BIN
docs/assets/features/disagg_encoder/disagg_encoder_flow.png
Normal file
BIN
docs/assets/features/disagg_encoder/disagg_encoder_flow.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 84 KiB |
@ -1,7 +1,16 @@
|
||||
# Meetups
|
||||
|
||||
We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:
|
||||
We host regular meetups around the world. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights.
|
||||
|
||||
## Upcoming Meetups
|
||||
|
||||
Stay tuned for upcoming meetups! Follow us on [Twitter/X](https://x.com/vllm_project), join our [Slack](https://slack.vllm.ai), and follow vLLM on [Luma](https://luma.com/vLLM-Meetups) to get notified about new events.
|
||||
|
||||
## Past Meetups
|
||||
|
||||
Below you'll find slides and recordings from our previous meetups:
|
||||
|
||||
- [vLLM Zurich Meetup](https://luma.com/0gls27kb), November 6th 2025. [[Slides]](https://docs.google.com/presentation/d/1UC9PTLCHYXQpOmJDSFg6Sljra3iVXzc09DeEI7dnxMc/edit?usp=sharing) [[Recording]](https://www.youtube.com/watch?v=6m6ZE6yVEDI)
|
||||
- [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/xSrYXjNgr1HbCP4ExYNG1w), November 1st 2025. [[Slides]](https://drive.google.com/drive/folders/1nQJ8ZkLSjKxvu36sSHaceVXtttbLvvu-?usp=drive_link)
|
||||
- [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/__xb4OyOsImz-9eAVrdlcg), October 25th 2025. [[Slides]](https://drive.google.com/drive/folders/1KqwjsFJLfEsC8wlDugnrR61zsWHt94Q6)
|
||||
- [vLLM Toronto Meetup](https://luma.com/e80e0ymm), September 25th 2025. [[Slides]](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing)
|
||||
@ -25,4 +34,12 @@ We host regular meetups in San Francisco Bay Area every 2 months. We will share
|
||||
- [The second vLLM meetup](https://lu.ma/ygxbpzhl), with IBM Research, January 31st 2024. [[Slides]](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing) [[Video (vLLM Update)]](https://youtu.be/Y0C-DUvEnZQ) [[Video (IBM Research & torch.compile)]](https://youtu.be/m0dMtFLI-dg)
|
||||
- [The first vLLM meetup](https://lu.ma/first-vllm-meetup), with a16z, October 5th 2023. [[Slides]](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing)
|
||||
|
||||
We are always looking for speakers and sponsors at San Francisco Bay Area and potentially other locations. If you are interested in speaking or sponsoring, please contact us at [vllm-questions@lists.berkeley.edu](mailto:vllm-questions@lists.berkeley.edu).
|
||||
## Get Involved
|
||||
|
||||
**Want to host or speak at a vLLM meetup?** We're always looking for speakers and sponsors for our meetups. Whether you want to:
|
||||
|
||||
- Share your vLLM feature, use case, project extension, or deployment experience
|
||||
- Host a meetup in your city
|
||||
- Sponsor an event
|
||||
|
||||
Please contact us at [vllm-questions@lists.berkeley.edu](mailto:vllm-questions@lists.berkeley.edu).
|
||||
|
||||
@ -56,13 +56,13 @@ The initialization code should look like this:
|
||||
|
||||
### Computation Code
|
||||
|
||||
- Add a `get_input_embeddings` method inside `MyModel` module that returns the text embeddings given `input_ids`. This is equivalent to directly calling the text embedding layer, but provides a unified interface in case `MyModel` is used within a composite multimodal model.
|
||||
- Add a `embed_input_ids` method inside `MyModel` module that returns the text embeddings given `input_ids`. This is equivalent to directly calling the text embedding layer, but provides a unified interface in case `MyModel` is used within a composite multimodal model.
|
||||
|
||||
```python
|
||||
class MyModel(nn.Module):
|
||||
...
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
@ -36,7 +36,7 @@ Further update the model as follows:
|
||||
|
||||
More conveniently, you can simply pass `**kwargs` to the [forward][torch.nn.Module.forward] method and retrieve the keyword parameters for multimodal inputs from it.
|
||||
|
||||
- Implement [get_multimodal_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
|
||||
- Implement [embed_multimodal][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
|
||||
|
||||
??? code
|
||||
|
||||
@ -49,7 +49,7 @@ Further update the model as follows:
|
||||
image_features = self.vision_encoder(image_input)
|
||||
return self.multi_modal_projector(image_features)
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
def embed_multimodal(
|
||||
self,
|
||||
**kwargs: object,
|
||||
) -> MultiModalEmbeddings | None:
|
||||
@ -69,7 +69,7 @@ Further update the model as follows:
|
||||
!!! note
|
||||
By default, vLLM merges the multimodal embeddings into text embeddings depending on the information of their locations defined in
|
||||
[PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] from input processing.
|
||||
This logic can be found at [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings].
|
||||
This logic can be found at [embed_input_ids][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_input_ids].
|
||||
|
||||
You may override this method if additional logic is required for your model when merging embeddings.
|
||||
|
||||
|
||||
@ -177,8 +177,9 @@ The following table lists backends that support full CUDA Graphs at the time of
|
||||
| FlashAttention v3 | `ALWAYS` | has unified routine for both batches, so `FULL` mode is good |
|
||||
| Triton Attention | `ALWAYS` | prefer `FULL_AND_PIECEWISE` since it has different kernels for prefill/mixed and pure decode batches |
|
||||
| AITER FlashAttention | `UNIFORM_BATCH`| |
|
||||
| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | |
|
||||
| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | Will be set to `UNIFORM_BATCH` when using TRTLLM attention on Blackwell |
|
||||
| FlashMLA | `UNIFORM_BATCH` | |
|
||||
| FlashInferMLA | `UNIFORM_BATCH` | |
|
||||
| AITER MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
|
||||
| CUTLASS MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
|
||||
| Mamba attention| `UNIFORM_SINGLE_TOKEN_DECODE` | |
|
||||
@ -218,16 +219,6 @@ outputs = model.generate(
|
||||
)
|
||||
```
|
||||
|
||||
### Migration from legacy flags
|
||||
|
||||
Legacy `use_cudagraph` and `full_cuda_graph` are unified by `cudagraph_mode`:
|
||||
|
||||
* `use_cudagraph=False` → `NONE`.
|
||||
* `use_cudagraph=True` and `full_cuda_graph=False` → `PIECEWISE`.
|
||||
* `full_cuda_graph=True` → directly set `FULL` and rely on the graceful fallback policy.
|
||||
|
||||
As they are deprecated and will be removed in the next major or minor release, i.e., v0.11.0 or v1.0.0, we recommend using cudagraph_mode instead.
|
||||
|
||||
### Piecewise compilation and full graph custom passes (attention fusion, sequence parallelism)
|
||||
|
||||
Unfortunately, some custom compile passes have to see the whole graph to be effective and hence aren't compatible with piecewise compilation. This includes `AttnFusionPass` and `SequenceParallelismPass`. As a short-term solution, we automatically disable piecewise compilation (by setting `splitting_ops=[]`) when attention fusion is enabled. We use CUDA Graph modes `FULL` or `FULL_DECODE_ONLY` (depending on backend support). However, this leads to another optimization incompatibility and confusing performance tradeoffs.
|
||||
|
||||
75
docs/features/disagg_encoder.md
Normal file
75
docs/features/disagg_encoder.md
Normal file
@ -0,0 +1,75 @@
|
||||
# Disaggregated Encoder
|
||||
|
||||
A **disaggregated encoder** runs the vision-encoder stage of a multimodal LLM in a process that is separate from the pre-fill / decoder stage. Deploying these two stages in independent vLLM instances brings three practical benefits:
|
||||
|
||||
1. **Independent, fine-grained scaling**
|
||||
2. **Lower time-to-first-token (TTFT)**
|
||||
3. **Cross-process reuse and caching of encoder outputs**
|
||||
|
||||
Design doc: <https://docs.google.com/document/d/1aed8KtC6XkXtdoV87pWT0a8OJlZ-CpnuLLzmR8l9BAE>
|
||||
|
||||
---
|
||||
|
||||
## 1 Motivation
|
||||
|
||||
### 1. Independent, fine-grained scaling
|
||||
|
||||
* Vision encoders are lightweight, while language models are orders of magnitude larger.
|
||||
* The language model can be parallelised without affecting the encoder fleet.
|
||||
* Encoder nodes can be added or removed independently.
|
||||
|
||||
### 2. Lower time-to-first-token (TTFT)
|
||||
|
||||
* Language-only requests bypass the vision encoder entirely.
|
||||
* Encoder output is injected only at required attention layers, shortening the pre-fill critical path.
|
||||
|
||||
### 3. Cross-process reuse and caching
|
||||
|
||||
* In-process encoders confine reuse to a single worker.
|
||||
* A remote, shared cache lets any worker retrieve existing embeddings, eliminating redundant computation.
|
||||
|
||||
---
|
||||
|
||||
## 2 Usage Example
|
||||
|
||||
The current reference pathway is **SharedStorageConnector**.
|
||||
Below ready-to-run scripts shows the workflow:
|
||||
|
||||
1 Encoder instance + 1 PD instance:
|
||||
`examples/online_serving/disaggregated_encoder/shared_storage_connector/disagg_encoder_example.sh`
|
||||
|
||||
1 Encoder instance + 1 Prefill instance + 1 Decode instance:
|
||||
`examples/online_serving/disaggregated_encoder/shared_storage_connector/disagg_epd_example.sh`
|
||||
|
||||
---
|
||||
|
||||
## 3 Test Script
|
||||
|
||||
Please refer to the directories `tests/v1/ec_connector`
|
||||
|
||||
## 4 Development
|
||||
|
||||
Disaggregated encoding is implemented by running two parts:
|
||||
|
||||
* **Encoder instance** – a vLLM instance to performs vision encoding.
|
||||
* **Prefill/Decode (PD) instance(s)** – runs language pre-fill and decode.
|
||||
* PD can be in either a single normal instance with `disagg_encoder_example.sh` (E->PD) or in disaggregated instances with `disagg_epd_example.sh` (E->P->D)
|
||||
|
||||
A connector transfers encoder-cache (EC) embeddings from the encoder instance to the PD instance.
|
||||
All related code is under `vllm/distributed/ec_transfer`.
|
||||
|
||||
### Key abstractions
|
||||
|
||||
* **ECConnector** – interface for retrieving EC caches produced by the encoder.
|
||||
* *Scheduler role* – checks cache existence and schedules loads.
|
||||
* *Worker role* – loads the embeddings into memory.
|
||||
|
||||
Here is a figure illustrating disaggregate encoder flow:
|
||||
|
||||

|
||||
|
||||
For the PD disaggregation part, the Prefill instance receive cache exactly the same as the disaggregate encoder flow above. Prefill instance executes 1 step (prefill -> 1 token output) and then transfer KV cache to the Decode instance for the remaining execution. The KV transfer part purely happens after the execute of the PDinstance.
|
||||
|
||||
`docs/features/disagg_prefill.md` shows the brief idea about the disaggregated prefill (v0)
|
||||
|
||||
We create the example setup with the **NixlConnector** from `vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py` and referred to the `tests/v1/kv_connector/nixl_integration/toy_proxy_server.py` to facilitate the kv transfer between P and D;
|
||||
@ -281,4 +281,36 @@ python quantize_quark.py --model_dir Qwen/Qwen1.5-MoE-A2.7B-Chat \
|
||||
--group_size 32
|
||||
```
|
||||
|
||||
The current integration supports [all combination of FP4, FP6_E3M2, FP6_E2M3](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py) used for either weights or activations. Eventually, some target hardware support mixed precision GEMM, as AMD Instinct MI350/MI355, for example using FP6 for activations and FP4 for weights.
|
||||
The current integration supports [all combination of FP4, FP6_E3M2, FP6_E2M3](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py) used for either weights or activations.
|
||||
|
||||
## Using Quark Quantized layerwise Auto Mixed Precision (AMP) Models
|
||||
|
||||
vLLM also supports loading layerwise mixed precision model quantized using AMD Quark. Currently, mixed scheme of {MXFP4, FP8} is supported, where FP8 here denotes for FP8 per-tensor scheme. More mixed precision schemes are planned to be supported in a near future, including
|
||||
|
||||
- Unquantized Linear and/or MoE layer(s) as an option for each layer, i.e., mixed of {MXFP4, FP8, BF16/FP16}
|
||||
- MXFP6 quantization extension, i.e., {MXFP4, MXFP6, FP8, BF16/FP16}
|
||||
|
||||
Although one can maximize serving throughput using the lowest precision supported on a given device (e.g. MXFP4 for AMD Instinct MI355, FP8 for AMD Instinct MI300), these aggressive schemes can be detrimental to accuracy recovering from quantization on target tasks. Mixed precision allows to strike a balance between maximizing accuracy and throughput.
|
||||
|
||||
There are two steps to generate and deploy a mixed precision model quantized with AMD Quark, as shown below.
|
||||
|
||||
### 1. Quantize a model using mixed precision in AMD Quark
|
||||
|
||||
Firstly, the layerwise mixed-precision configuration for a given LLM model is searched and then quantized using AMD Quark. We will provide a detailed tutorial with Quark APIs later.
|
||||
|
||||
As examples, we provide some ready-to-use quantized mixed precision model to show the usage in vLLM and the accuracy benifits. They are:
|
||||
|
||||
- amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
|
||||
- amd/Mixtral-8x7B-Instruct-v0.1-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
|
||||
- amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
|
||||
|
||||
### 2. inference the quantized mixed precision model in vLLM
|
||||
|
||||
Models quantized with AMD Quark using mixed precision can natively be reload in vLLM, and e.g. evaluated using lm-evaluation-harness as follow:
|
||||
|
||||
```bash
|
||||
lm_eval --model vllm \
|
||||
--model_args pretrained=amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8,tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.8,trust_remote_code=False \
|
||||
--tasks mmlu \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
@ -11,7 +11,7 @@ Key benefits:
|
||||
- **Fine-grained control**: Optionally wake up only model weights or KV cache to avoid OOM during weight updates.
|
||||
|
||||
!!! note
|
||||
This feature is only supported on CUDA platform.
|
||||
This feature is now supported on CUDA and ROCm platform.
|
||||
|
||||
!!! note
|
||||
For more information, see this [Blog Post](https://blog.vllm.ai/2025/10/26/sleep-mode.html).
|
||||
@ -116,3 +116,7 @@ curl -X POST 'http://localhost:8000/wake_up?tags=kv_cache'
|
||||
|
||||
!!! note
|
||||
These endpoints are only available when passing `VLLM_SERVER_DEV_MODE=1`.
|
||||
|
||||
## Limitation
|
||||
|
||||
On ROCm, the virtual memory allocation on ROCm is done through chunked memory allocation. You can control the chunk size through `VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE` (in MB). The default value is set at 256MB. The larger the chunk size the faster the performance. However, setting it too large will cause OOM. So if you encounter OOM when using sleep mode. Try reducing the chunk size. It is recommended to define the chunk size as a power of 2.
|
||||
|
||||
@ -93,7 +93,7 @@ Currently, there are no pre-built CPU wheels.
|
||||
|
||||
## Related runtime environment variables
|
||||
|
||||
- `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`.
|
||||
- `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`.
|
||||
- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists, `auto` (by default), or `nobind` (to disable binding to individual CPU cores and to inherit user-defined OpenMP variables). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively. If set to `nobind`, the number of OpenMP threads is determined by the standard `OMP_NUM_THREADS` environment variable.
|
||||
- `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `None`. If the value is not set and use `auto` thread binding, no CPU will be reserved for `world_size == 1`, 1 CPU per rank will be reserved for `world_size > 1`.
|
||||
- `CPU_VISIBLE_MEMORY_NODES`: specify visible NUMA memory nodes for vLLM CPU workers, similar to ```CUDA_VISIBLE_DEVICES```. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. The variable provides more control for the auto thread-binding feature, such as masking nodes and changing nodes binding sequence.
|
||||
@ -128,7 +128,7 @@ Note, it is recommended to manually reserve 1 CPU for vLLM front-end process whe
|
||||
|
||||
### How to decide `VLLM_CPU_OMP_THREADS_BIND`?
|
||||
|
||||
- Default `auto` thread-binding is recommended for most cases. Ideally, each OpenMP thread will be bound to a dedicated physical core respectively, threads of each rank will be bound to a same NUMA node respectively, and 1 CPU per rank will be reserved for other vLLM components when `world_size > 1`. If have any performance problems or unexpected binding behaviours, please try to bind threads as following.
|
||||
- Default `auto` thread-binding is recommended for most cases. Ideally, each OpenMP thread will be bound to a dedicated physical core respectively, threads of each rank will be bound to the same NUMA node respectively, and 1 CPU per rank will be reserved for other vLLM components when `world_size > 1`. If you have any performance problems or unexpected binding behaviours, please try to bind threads as following.
|
||||
|
||||
- On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:
|
||||
|
||||
@ -156,12 +156,12 @@ Note, it is recommended to manually reserve 1 CPU for vLLM front-end process whe
|
||||
14 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000
|
||||
15 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000
|
||||
|
||||
# On this platform, it is recommend to only bind openMP threads on logical CPU cores 0-7 or 8-15
|
||||
# On this platform, it is recommended to only bind openMP threads on logical CPU cores 0-7 or 8-15
|
||||
$ export VLLM_CPU_OMP_THREADS_BIND=0-7
|
||||
$ python examples/offline_inference/basic/basic.py
|
||||
```
|
||||
|
||||
- When deploy vLLM CPU backend on a multi-socket machine with NUMA and enable tensor parallel or pipeline parallel, each NUMA node is treated as a TP/PP rank. So be aware to set CPU cores of a single rank on a same NUMA node to avoid cross NUMA node memory access.
|
||||
- When deploying vLLM CPU backend on a multi-socket machine with NUMA and enable tensor parallel or pipeline parallel, each NUMA node is treated as a TP/PP rank. So be aware to set CPU cores of a single rank on the same NUMA node to avoid cross NUMA node memory access.
|
||||
|
||||
### How to decide `VLLM_CPU_KVCACHE_SPACE`?
|
||||
|
||||
@ -171,7 +171,9 @@ This value is 4GB by default. Larger space can support more concurrent requests,
|
||||
|
||||
First of all, please make sure the thread-binding and KV cache space are properly set and take effect. You can check the thread-binding by running a vLLM benchmark and observing CPU cores usage via `htop`.
|
||||
|
||||
Inference batch size is an important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM:
|
||||
Use multiples of 32 as `--block-size`, which is 128 by default.
|
||||
|
||||
Inference batch size is an important parameter for the performance. A larger batch usually provides higher throughput, a smaller batch provides lower latency. Tuning the max batch size starting from the default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM:
|
||||
|
||||
- `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impacts on the first token performance. The default value is set as:
|
||||
- Offline Inference: `4096 * world_size`
|
||||
@ -192,8 +194,8 @@ vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel
|
||||
### (x86 only) What is the purpose of `VLLM_CPU_MOE_PREPACK` and `VLLM_CPU_SGL_KERNEL`?
|
||||
|
||||
- Both of them require `amx` CPU flag.
|
||||
- `VLLM_CPU_MOE_PREPACK` can provides better performance for MoE models
|
||||
- `VLLM_CPU_SGL_KERNEL` can provides better performance for MoE models and small-batch scenarios.
|
||||
- `VLLM_CPU_MOE_PREPACK` can provide better performance for MoE models
|
||||
- `VLLM_CPU_SGL_KERNEL` can provide better performance for MoE models and small-batch scenarios.
|
||||
|
||||
### Why do I see `get_mempolicy: Operation not permitted` when running in Docker?
|
||||
|
||||
|
||||
@ -75,7 +75,12 @@ This section details the necessary modifications to make to a Transformers compa
|
||||
To make your model compatible with the Transformers backend, it needs:
|
||||
|
||||
1. `kwargs` passed down through all modules from `MyModel` to `MyAttention`.
|
||||
1. If your model is encoder-only, you must also add `is_causal = False` to `MyAttention`.
|
||||
- If your model is encoder-only:
|
||||
1. Add `is_causal = False` to `MyAttention`.
|
||||
- If your model is mixture-of-experts (MoE):
|
||||
1. Your sparse MoE block must have an attribute called `experts`.
|
||||
2. The class of `experts` (`MyExperts`) must inherit from `nn.ModuleList`.
|
||||
3. `MyExperts.forward` must accept `hidden_states`, `top_k_index`, `top_k_weights`.
|
||||
2. `MyAttention` must use `ALL_ATTENTION_FUNCTIONS` to call attention.
|
||||
3. `MyModel` must contain `_supports_attention_backend = True`.
|
||||
|
||||
@ -102,6 +107,23 @@ class MyAttention(nn.Module):
|
||||
)
|
||||
...
|
||||
|
||||
# Only do this for mixture-of-experts models
|
||||
class MyExperts(nn.ModuleList):
|
||||
def forward(self, hidden_states, top_k_index, top_k_weights):
|
||||
...
|
||||
|
||||
# Only do this for mixture-of-experts models
|
||||
class MySparseMoEBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
...
|
||||
self.experts = MyExperts(config)
|
||||
...
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor):
|
||||
...
|
||||
hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
|
||||
...
|
||||
|
||||
class MyModel(PreTrainedModel):
|
||||
_supports_attention_backend = True
|
||||
```
|
||||
|
||||
@ -77,11 +77,11 @@ In addition, we have the following custom APIs:
|
||||
|
||||
In order for the language model to support chat protocol, vLLM requires the model to include
|
||||
a chat template in its tokenizer configuration. The chat template is a Jinja2 template that
|
||||
specifies how are roles, messages, and other chat-specific tokens are encoded in the input.
|
||||
specifies how roles, messages, and other chat-specific tokens are encoded in the input.
|
||||
|
||||
An example chat template for `NousResearch/Meta-Llama-3-8B-Instruct` can be found [here](https://github.com/meta-llama/llama3?tab=readme-ov-file#instruction-tuned-models)
|
||||
|
||||
Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those model,
|
||||
Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those models,
|
||||
you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat
|
||||
template, or the template in string form. Without a chat template, the server will not be able to process chat
|
||||
and all chat requests will error.
|
||||
|
||||
@ -4,8 +4,7 @@
|
||||
This file demonstrates the usage of text generation with an LLM model,
|
||||
comparing the performance with and without speculative decoding.
|
||||
|
||||
Note that still not support `v1`:
|
||||
VLLM_USE_V1=0 python examples/offline_inference/mlpspeculator.py
|
||||
Note that this example is out of date and not supported in vLLM v1.
|
||||
"""
|
||||
|
||||
import gc
|
||||
|
||||
@ -11,12 +11,10 @@ python examples/offline_inference/qwen2_5_omni/only_thinker.py \
|
||||
|
||||
# Read vision and audio inputs from a single video file
|
||||
# NOTE: V1 engine does not support interleaved modalities yet.
|
||||
VLLM_USE_V1=0 \
|
||||
python examples/offline_inference/qwen2_5_omni/only_thinker.py \
|
||||
-q use_audio_in_video
|
||||
|
||||
# Multiple audios
|
||||
VLLM_USE_V1=0 \
|
||||
python examples/offline_inference/qwen2_5_omni/only_thinker.py \
|
||||
-q multi_audios
|
||||
```
|
||||
|
||||
@ -7,7 +7,6 @@ with the correct prompt format on Qwen2.5-Omni (thinker only).
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.assets.image import ImageAsset
|
||||
@ -72,11 +71,7 @@ def get_use_audio_in_video_query() -> QueryResult:
|
||||
)
|
||||
asset = VideoAsset(name="baby_reading", num_frames=16)
|
||||
audio = asset.get_audio(sampling_rate=16000)
|
||||
assert not envs.VLLM_USE_V1, (
|
||||
"V1 does not support use_audio_in_video. "
|
||||
"Please launch this example with "
|
||||
"`VLLM_USE_V1=0`."
|
||||
)
|
||||
|
||||
return QueryResult(
|
||||
inputs={
|
||||
"prompt": prompt,
|
||||
|
||||
119
examples/online_serving/disaggregated_encoder/README.md
Normal file
119
examples/online_serving/disaggregated_encoder/README.md
Normal file
@ -0,0 +1,119 @@
|
||||
# Disaggregated Encoder
|
||||
|
||||
These example scripts that demonstrate the disaggregated encoder (EPD) features of vLLM.
|
||||
|
||||
For a detailed explanation of the EPD features, please refer to the [Disaggregated Encoder Feature Documentation](../../../docs/features/disagg_encoder.md).
|
||||
|
||||
## Files
|
||||
|
||||
- `disagg_epd_proxy.py` - Proxy script that demonstrates the XeYpZd setup (X encode instances, Y prefill instances, Z decode instances). Currently stable for the 1e1p1d configuration.
|
||||
|
||||
- `disagg_1e1p1d_example.sh` - Sets up the 1e1p1d configuration, runs the VisionArena benchmark, and processes a single request with a local image.
|
||||
|
||||
- `disagg_1e1pd_example.sh` - Sets up the 1e1pd configuration, runs the VisionArena benchmark, and processes a single request with a local image.
|
||||
|
||||
### Custom Configuration
|
||||
|
||||
```bash
|
||||
# Use specific GPUs
|
||||
GPU_E=0 GPU_PD=1 GPU_P=1 GPU_D=2 bash disagg_1e1p1d_example.sh
|
||||
|
||||
# Use specific ports
|
||||
ENDPOINT_PORT=10001 bash disagg_1e1p1d_example.sh
|
||||
|
||||
# Use specific model
|
||||
MODEL="Qwen/Qwen2.5-VL-3B-Instruct" bash disagg_1e1p1d_example.sh
|
||||
|
||||
# Use specific storage path
|
||||
EC_SHARED_STORAGE_PATH="/tmp/my_ec_cache" bash disagg_1e1p1d_example.sh
|
||||
```
|
||||
|
||||
## Encoder Instances
|
||||
|
||||
Encoder engines should be launched with the following flags:
|
||||
|
||||
- `--enforce-eager` **(required)** – The current EPD implementation is only compatible with encoder instances running in this mode.
|
||||
|
||||
- `--no-enable-prefix-caching` **(required)** – Encoder instances do not consume KV cache; prefix caching is disabled to avoid conflicts with other features.
|
||||
|
||||
- `--max-num-batched-tokens=<large value>` **(default: 2048)** – This flag controls the token scheduling budget per decoding step and is irrelevant to encoder-only instances. **Set it to a very high value (effectively unlimited) to bypass scheduler limitations.** The actual token budget is managed by the encoder cache manager.
|
||||
|
||||
## Local media inputs
|
||||
|
||||
To support local image inputs (from your ```MEDIA_PATH``` directory), add the following flag to the encoder instance:
|
||||
|
||||
```bash
|
||||
--allowed-local-media-path $MEDIA_PATH
|
||||
```
|
||||
|
||||
The vllm instances and `disagg_encoder_proxy` supports local URIs with ```{"url": "file://'"$MEDIA_PATH_FILENAME"'}``` as multimodal inputs. Each URI is passed unchanged from the `disagg_encoder_proxy` to the encoder instance so that the encoder can load the media locally.
|
||||
|
||||
## EC connector and KV transfer
|
||||
|
||||
The `ECSharedStorageConnector` is used to store the encoder cache on local disk and facilitate transfer. To enable the encoder disaggregation feature, add the following configuration:
|
||||
|
||||
```bash
|
||||
# Add to encoder instance:
|
||||
--ec-transfer-config '{
|
||||
"ec_connector": "ECSharedStorageConnector",
|
||||
"ec_role": "ec_producer",
|
||||
"ec_connector_extra_config": {
|
||||
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
|
||||
}
|
||||
}'
|
||||
|
||||
# Add to prefill/prefill+decode instance:
|
||||
--ec-transfer-config '{
|
||||
"ec_connector": "ECSharedStorageConnector",
|
||||
"ec_role": "ec_consumer",
|
||||
"ec_connector_extra_config": {
|
||||
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
`$EC_SHARED_STORAGE_PATH` is the path where the EC connector temporarily stores the cache.
|
||||
|
||||
If you enable prefill instance (`--prefill-servers-urls` not disabled), you will need --kv-transfer-config to facilitate the PD disaggregation. Currently, we use the `NixlConnector` for this purpose. Refer to `tests/v1/kv_connector/nixl_integration` for more example codes on PD disaggregation with Nixl.
|
||||
|
||||
```bash
|
||||
# Add to prefill instance:
|
||||
--kv-transfer-config '{
|
||||
"kv_connector": "NixlConnector",
|
||||
"kv_role": "kv_producer"
|
||||
}'
|
||||
|
||||
# Add to decode instance:
|
||||
--kv-transfer-config '{
|
||||
"kv_connector": "NixlConnector",
|
||||
"kv_role": "kv_consumer"
|
||||
}'
|
||||
```
|
||||
|
||||
## Proxy Instance Flags (`disagg_epd_proxy.py`)
|
||||
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--encode-servers-urls` | Comma-separated list of encoder endpoints. Every multimodal item extracted from the request is fanned out to one of these URLs in a round-robin fashion. |
|
||||
| `--prefill-servers-urls` | Comma-separated list of prefill endpoints. Set to `disable`, `none`, or `""` to skip the dedicated prefill phase and run E+PD (encoder + combined prefill/decode). |
|
||||
| `--decode-servers-urls` | Comma-separated list of decode endpoints. Non-stream and stream paths both round-robin over this list. |
|
||||
| `--host`, `--port` | Bind address for the proxy itself (defaults: `0.0.0.0:8000`). |
|
||||
|
||||
Example usage:
|
||||
For E + PD setup:
|
||||
|
||||
```bash
|
||||
$ python disagg_encoder_proxy.py \
|
||||
--encode-servers-urls "http://e1:8001,http://e2:8002" \
|
||||
--prefill-servers-urls "disable" \
|
||||
--decode-servers-urls "http://pd1:8003,http://pd2:8004"
|
||||
```
|
||||
|
||||
For E + P + D setup:
|
||||
|
||||
```bash
|
||||
$ python disagg_encoder_proxy.py \
|
||||
--encode-servers-urls "http://e1:8001,http://e2:8001" \
|
||||
--prefill-servers-urls "http://p1:8003,http://p2:8004" \
|
||||
--decode-servers-urls "http://d1:8005,http://d2:8006"
|
||||
```
|
||||
@ -0,0 +1,221 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
declare -a PIDS=()
|
||||
|
||||
###############################################################################
|
||||
# Configuration -- override via env before running
|
||||
###############################################################################
|
||||
MODEL="${MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}"
|
||||
LOG_PATH="${LOG_PATH:-./logs}"
|
||||
mkdir -p $LOG_PATH
|
||||
|
||||
ENCODE_PORT="${ENCODE_PORT:-19534}"
|
||||
PREFILL_PORT="${PREFILL_PORT:-19535}"
|
||||
DECODE_PORT="${DECODE_PORT:-19536}"
|
||||
PROXY_PORT="${PROXY_PORT:-10001}"
|
||||
|
||||
GPU_E="${GPU_E:-2}"
|
||||
GPU_P="${GPU_P:-2}"
|
||||
GPU_D="${GPU_D:-3}"
|
||||
|
||||
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}"
|
||||
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout
|
||||
|
||||
NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark
|
||||
|
||||
export UCX_TLS=all
|
||||
export UCX_NET_DEVICES=all
|
||||
|
||||
###############################################################################
|
||||
# Helpers
|
||||
###############################################################################
|
||||
# Find the git repository root directory
|
||||
GIT_ROOT=$(git rev-parse --show-toplevel)
|
||||
|
||||
START_TIME=$(date +"%Y%m%d_%H%M%S")
|
||||
ENC_LOG=$LOG_PATH/encoder_${START_TIME}.log
|
||||
P_LOG=$LOG_PATH/p_${START_TIME}.log
|
||||
D_LOG=$LOG_PATH/d_${START_TIME}.log
|
||||
PROXY_LOG=$LOG_PATH/proxy_${START_TIME}.log
|
||||
|
||||
wait_for_server() {
|
||||
local port=$1
|
||||
timeout "$TIMEOUT_SECONDS" bash -c "
|
||||
until curl -s localhost:$port/v1/chat/completions > /dev/null; do
|
||||
sleep 1
|
||||
done" && return 0 || return 1
|
||||
}
|
||||
|
||||
# Cleanup function
|
||||
cleanup() {
|
||||
echo "Stopping everything…"
|
||||
trap - INT TERM USR1 # prevent re-entrancy
|
||||
|
||||
# Kill all tracked PIDs
|
||||
for pid in "${PIDS[@]}"; do
|
||||
if kill -0 "$pid" 2>/dev/null; then
|
||||
echo "Killing process $pid"
|
||||
kill "$pid" 2>/dev/null
|
||||
fi
|
||||
done
|
||||
|
||||
# Wait a moment for graceful shutdown
|
||||
sleep 2
|
||||
|
||||
# Force kill any remaining processes
|
||||
for pid in "${PIDS[@]}"; do
|
||||
if kill -0 "$pid" 2>/dev/null; then
|
||||
echo "Force killing process $pid"
|
||||
kill -9 "$pid" 2>/dev/null
|
||||
fi
|
||||
done
|
||||
|
||||
# Kill the entire process group as backup
|
||||
kill -- -$$ 2>/dev/null
|
||||
|
||||
echo "All processes stopped."
|
||||
exit 0
|
||||
}
|
||||
|
||||
trap cleanup INT
|
||||
trap cleanup USR1
|
||||
trap cleanup TERM
|
||||
|
||||
# clear previous cache
|
||||
echo "remove previous ec cache folder"
|
||||
rm -rf $EC_SHARED_STORAGE_PATH
|
||||
|
||||
echo "make ec cache folder"
|
||||
mkdir -p $EC_SHARED_STORAGE_PATH
|
||||
|
||||
###############################################################################
|
||||
# Encoder worker
|
||||
###############################################################################
|
||||
CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization 0.01 \
|
||||
--port "$ENCODE_PORT" \
|
||||
--enforce-eager \
|
||||
--enable-request-id-headers \
|
||||
--no-enable-prefix-caching \
|
||||
--max-num-batched-tokens 114688 \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
|
||||
--ec-transfer-config '{
|
||||
"ec_connector": "ECSharedStorageConnector",
|
||||
"ec_role": "ec_producer",
|
||||
"ec_connector_extra_config": {
|
||||
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
|
||||
}
|
||||
}' \
|
||||
>"${ENC_LOG}" 2>&1 &
|
||||
|
||||
PIDS+=($!)
|
||||
|
||||
###############################################################################
|
||||
# Prefill worker
|
||||
###############################################################################
|
||||
CUDA_VISIBLE_DEVICES="$GPU_P" \
|
||||
UCX_NET_DEVICES=all \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \
|
||||
vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--port "$PREFILL_PORT" \
|
||||
--enforce-eager \
|
||||
--enable-request-id-headers \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
|
||||
--ec-transfer-config '{
|
||||
"ec_connector": "ECSharedStorageConnector",
|
||||
"ec_role": "ec_consumer",
|
||||
"ec_connector_extra_config": {
|
||||
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
|
||||
}
|
||||
}' \
|
||||
--kv-transfer-config '{
|
||||
"kv_connector": "NixlConnector",
|
||||
"kv_role": "kv_producer"
|
||||
}' \
|
||||
>"${P_LOG}" 2>&1 &
|
||||
|
||||
PIDS+=($!)
|
||||
|
||||
###############################################################################
|
||||
# Decode worker
|
||||
###############################################################################
|
||||
CUDA_VISIBLE_DEVICES="$GPU_D" \
|
||||
UCX_NET_DEVICES=all \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \
|
||||
vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--port "$DECODE_PORT" \
|
||||
--enforce-eager \
|
||||
--enable-request-id-headers \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
|
||||
--kv-transfer-config '{
|
||||
"kv_connector": "NixlConnector",
|
||||
"kv_role": "kv_consumer"
|
||||
}' \
|
||||
>"${D_LOG}" 2>&1 &
|
||||
|
||||
PIDS+=($!)
|
||||
|
||||
# Wait for workers
|
||||
wait_for_server $ENCODE_PORT
|
||||
wait_for_server $PREFILL_PORT
|
||||
wait_for_server $DECODE_PORT
|
||||
|
||||
###############################################################################
|
||||
# Proxy
|
||||
###############################################################################
|
||||
python disagg_epd_proxy.py \
|
||||
--host "0.0.0.0" \
|
||||
--port "$PROXY_PORT" \
|
||||
--encode-servers-urls "http://localhost:$ENCODE_PORT" \
|
||||
--prefill-servers-urls "http://localhost:$PREFILL_PORT" \
|
||||
--decode-servers-urls "http://localhost:$DECODE_PORT" \
|
||||
>"${PROXY_LOG}" 2>&1 &
|
||||
|
||||
PIDS+=($!)
|
||||
|
||||
wait_for_server $PROXY_PORT
|
||||
echo "All services are up!"
|
||||
|
||||
###############################################################################
|
||||
# Benchmark
|
||||
###############################################################################
|
||||
echo "Running benchmark (stream)..."
|
||||
vllm bench serve \
|
||||
--model $MODEL \
|
||||
--backend openai-chat \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
--dataset-path lmarena-ai/VisionArena-Chat \
|
||||
--seed 0 \
|
||||
--num-prompts $NUM_PROMPTS \
|
||||
--port $PROXY_PORT
|
||||
|
||||
PIDS+=($!)
|
||||
|
||||
###############################################################################
|
||||
# Single request with local image
|
||||
###############################################################################
|
||||
echo "Running single request with local image (non-stream)..."
|
||||
curl http://127.0.0.1:${PROXY_PORT}/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "'${MODEL}'",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": [
|
||||
{"type": "image_url", "image_url": {"url": "file://'"${GIT_ROOT}"'/tests/v1/ec_connector/integration/hato.jpg"}},
|
||||
{"type": "text", "text": "What is in this image?"}
|
||||
]}
|
||||
]
|
||||
}'
|
||||
|
||||
|
||||
# cleanup
|
||||
echo "cleanup..."
|
||||
cleanup
|
||||
@ -0,0 +1,186 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
declare -a PIDS=()
|
||||
|
||||
###############################################################################
|
||||
# Configuration -- override via env before running
|
||||
###############################################################################
|
||||
MODEL="${MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}"
|
||||
LOG_PATH="${LOG_PATH:-./logs}"
|
||||
mkdir -p $LOG_PATH
|
||||
|
||||
ENCODE_PORT="${ENCODE_PORT:-19534}"
|
||||
PREFILL_DECODE_PORT="${PREFILL_DECODE_PORT:-19535}"
|
||||
PROXY_PORT="${PROXY_PORT:-10001}"
|
||||
|
||||
GPU_E="${GPU_E:-0}"
|
||||
GPU_PD="${GPU_PD:-1}"
|
||||
|
||||
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}"
|
||||
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout
|
||||
|
||||
NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark
|
||||
|
||||
###############################################################################
|
||||
# Helpers
|
||||
###############################################################################
|
||||
# Find the git repository root directory
|
||||
GIT_ROOT=$(git rev-parse --show-toplevel)
|
||||
|
||||
START_TIME=$(date +"%Y%m%d_%H%M%S")
|
||||
ENC_LOG=$LOG_PATH/encoder_${START_TIME}.log
|
||||
PD_LOG=$LOG_PATH/pd_${START_TIME}.log
|
||||
PROXY_LOG=$LOG_PATH/proxy_${START_TIME}.log
|
||||
|
||||
wait_for_server() {
|
||||
local port=$1
|
||||
timeout "$TIMEOUT_SECONDS" bash -c "
|
||||
until curl -s localhost:$port/v1/chat/completions > /dev/null; do
|
||||
sleep 1
|
||||
done" && return 0 || return 1
|
||||
}
|
||||
|
||||
# Cleanup function
|
||||
cleanup() {
|
||||
echo "Stopping everything…"
|
||||
trap - INT TERM USR1 # prevent re-entrancy
|
||||
|
||||
# Kill all tracked PIDs
|
||||
for pid in "${PIDS[@]}"; do
|
||||
if kill -0 "$pid" 2>/dev/null; then
|
||||
echo "Killing process $pid"
|
||||
kill "$pid" 2>/dev/null
|
||||
fi
|
||||
done
|
||||
|
||||
# Wait a moment for graceful shutdown
|
||||
sleep 2
|
||||
|
||||
# Force kill any remaining processes
|
||||
for pid in "${PIDS[@]}"; do
|
||||
if kill -0 "$pid" 2>/dev/null; then
|
||||
echo "Force killing process $pid"
|
||||
kill -9 "$pid" 2>/dev/null
|
||||
fi
|
||||
done
|
||||
|
||||
# Kill the entire process group as backup
|
||||
kill -- -$$ 2>/dev/null
|
||||
|
||||
echo "All processes stopped."
|
||||
exit 0
|
||||
}
|
||||
|
||||
trap cleanup INT
|
||||
trap cleanup USR1
|
||||
trap cleanup TERM
|
||||
|
||||
# clear previous cache
|
||||
echo "remove previous ec cache folder"
|
||||
rm -rf $EC_SHARED_STORAGE_PATH
|
||||
|
||||
echo "make ec cache folder"
|
||||
mkdir -p $EC_SHARED_STORAGE_PATH
|
||||
|
||||
###############################################################################
|
||||
# Encoder worker
|
||||
###############################################################################
|
||||
CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization 0.01 \
|
||||
--port "$ENCODE_PORT" \
|
||||
--enforce-eager \
|
||||
--enable-request-id-headers \
|
||||
--no-enable-prefix-caching \
|
||||
--max-num-batched-tokens 114688 \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
|
||||
--ec-transfer-config '{
|
||||
"ec_connector": "ECSharedStorageConnector",
|
||||
"ec_role": "ec_producer",
|
||||
"ec_connector_extra_config": {
|
||||
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
|
||||
}
|
||||
}' \
|
||||
>"${ENC_LOG}" 2>&1 &
|
||||
|
||||
PIDS+=($!)
|
||||
|
||||
###############################################################################
|
||||
# Prefill+Decode worker
|
||||
###############################################################################
|
||||
CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--port "$PREFILL_DECODE_PORT" \
|
||||
--enforce-eager \
|
||||
--enable-request-id-headers \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
|
||||
--ec-transfer-config '{
|
||||
"ec_connector": "ECSharedStorageConnector",
|
||||
"ec_role": "ec_consumer",
|
||||
"ec_connector_extra_config": {
|
||||
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
|
||||
}
|
||||
}' \
|
||||
>"${PD_LOG}" 2>&1 &
|
||||
|
||||
PIDS+=($!)
|
||||
|
||||
# Wait for workers
|
||||
wait_for_server $ENCODE_PORT
|
||||
wait_for_server $PREFILL_DECODE_PORT
|
||||
|
||||
###############################################################################
|
||||
# Proxy
|
||||
###############################################################################
|
||||
python disagg_epd_proxy.py \
|
||||
--host "0.0.0.0" \
|
||||
--port "$PROXY_PORT" \
|
||||
--encode-servers-urls "http://localhost:$ENCODE_PORT" \
|
||||
--prefill-servers-urls "disable" \
|
||||
--decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \
|
||||
>"${PROXY_LOG}" 2>&1 &
|
||||
|
||||
PIDS+=($!)
|
||||
|
||||
wait_for_server $PROXY_PORT
|
||||
echo "All services are up!"
|
||||
|
||||
###############################################################################
|
||||
# Benchmark
|
||||
###############################################################################
|
||||
echo "Running benchmark (stream)..."
|
||||
vllm bench serve \
|
||||
--model $MODEL \
|
||||
--backend openai-chat \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
--dataset-path lmarena-ai/VisionArena-Chat \
|
||||
--seed 0 \
|
||||
--num-prompts $NUM_PROMPTS \
|
||||
--port $PROXY_PORT
|
||||
|
||||
PIDS+=($!)
|
||||
|
||||
###############################################################################
|
||||
# Single request with local image
|
||||
###############################################################################
|
||||
echo "Running single request with local image (non-stream)..."
|
||||
curl http://127.0.0.1:${PROXY_PORT}/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "'${MODEL}'",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": [
|
||||
{"type": "image_url", "image_url": {"url": "file://'"${GIT_ROOT}"'/tests/v1/ec_connector/integration/hato.jpg"}},
|
||||
{"type": "text", "text": "What is in this image?"}
|
||||
]}
|
||||
]
|
||||
}'
|
||||
|
||||
|
||||
# cleanup
|
||||
echo "cleanup..."
|
||||
cleanup
|
||||
@ -0,0 +1,606 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
disagg_encoder_proxy.py
|
||||
|
||||
Proxy that routes OpenAI-compatible “/v1/chat/completions” requests to two
|
||||
clusters:
|
||||
• encode (multimodal feature extraction)
|
||||
• decode (language-model inference)
|
||||
|
||||
For MM input we:
|
||||
1. Extract *every* image/audio item.
|
||||
2. Fire N concurrent requests to the encoder cluster
|
||||
(one request per item, with **all text removed**).
|
||||
3. Wait for all of them to succeed.
|
||||
4. Forward the *original* request to a decode server.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import aiohttp
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
###############################################################################
|
||||
# FastAPI app & global state
|
||||
###############################################################################
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG, format="%(asctime)s %(levelname)s: %(message)s"
|
||||
)
|
||||
logger = logging.getLogger("proxy")
|
||||
|
||||
app = FastAPI()
|
||||
encode_session: aiohttp.ClientSession | None = None
|
||||
prefill_session: aiohttp.ClientSession | None = None
|
||||
decode_session: aiohttp.ClientSession | None = None
|
||||
|
||||
###############################################################################
|
||||
# Utils
|
||||
###############################################################################
|
||||
|
||||
|
||||
MM_TYPES = {"image_url", "audio_url", "input_audio"}
|
||||
|
||||
|
||||
def extract_mm_items(request_data: dict) -> list[dict]:
|
||||
"""
|
||||
Return *all* image/audio items that appear anywhere in `messages`.
|
||||
|
||||
Each returned dict looks like:
|
||||
{ "type": "image_url", "image_url": {...} }
|
||||
"""
|
||||
items: list[dict] = []
|
||||
for msg in request_data.get("messages", []):
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
|
||||
for item in content:
|
||||
if item.get("type") in MM_TYPES:
|
||||
items.append(item)
|
||||
return items
|
||||
|
||||
|
||||
async def fanout_encoder_primer(
|
||||
orig_request: dict,
|
||||
e_urls: list[str],
|
||||
req_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
1. Build one request *per MM item* with all text removed.
|
||||
2. Send them concurrently to the encode cluster.
|
||||
3. Raise if any of them fails.
|
||||
"""
|
||||
logger.info("[%s] Processing multimodal items...", req_id)
|
||||
|
||||
mm_items = extract_mm_items(orig_request)
|
||||
if not mm_items:
|
||||
logger.info("[%s] No multimodal items, skipping encoder", req_id)
|
||||
return # nothing to do
|
||||
|
||||
logger.info("[%s] got %d multimodal items...", req_id, len(mm_items))
|
||||
|
||||
tasks = []
|
||||
|
||||
# Round-robin over encode servers to distribute load a bit
|
||||
url_cycle = (e_urls[i % len(e_urls)] for i in range(len(mm_items)))
|
||||
|
||||
for idx, (item, target_url) in enumerate(zip(mm_items, url_cycle)):
|
||||
# Derive a *child* request id: <parent>:<index>:<random-short>
|
||||
child_req_id = f"{req_id}:{idx}:{uuid.uuid4().hex[:6]}"
|
||||
headers = {"x-request-id": child_req_id}
|
||||
|
||||
encoder_req = {
|
||||
# You *may* need to keep additional fields
|
||||
"model": orig_request.get("model"),
|
||||
"messages": [
|
||||
{"role": "user", "content": [item]},
|
||||
],
|
||||
# Only need 1 token so the server actually runs the encoder path
|
||||
"max_tokens": 1,
|
||||
"stream": False,
|
||||
}
|
||||
tasks.append(
|
||||
encode_session.post(
|
||||
f"{target_url}/v1/chat/completions",
|
||||
json=encoder_req,
|
||||
headers=headers,
|
||||
)
|
||||
)
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Fail fast if any sub-request failed
|
||||
for idx, r in enumerate(results):
|
||||
if isinstance(r, Exception):
|
||||
logger.error(
|
||||
"[%s] Encoder request #%d raised exception: %s",
|
||||
req_id,
|
||||
idx,
|
||||
r,
|
||||
exc_info=r,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502, detail=f"Encoder request failed: {str(r)}"
|
||||
)
|
||||
if r.status != 200:
|
||||
try:
|
||||
detail = await r.text()
|
||||
except Exception:
|
||||
detail = "<unable to read body>"
|
||||
logger.error(
|
||||
"[%s] Encoder request #%d returned status %s: %s",
|
||||
req_id,
|
||||
idx,
|
||||
r.status,
|
||||
detail,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=r.status,
|
||||
detail=f"Encoder request failed: {detail}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[%s] All %d encoder requests completed successfully", req_id, len(mm_items)
|
||||
)
|
||||
|
||||
|
||||
async def maybe_prefill(
|
||||
req_data: dict,
|
||||
p_url: str,
|
||||
req_id: str,
|
||||
) -> dict:
|
||||
"""
|
||||
- Do prefill-only task if p_url exist;
|
||||
- Return modified request data with kv transfer params (for nixl connector)
|
||||
- Else, skip and return the original request data for decode
|
||||
"""
|
||||
if p_url:
|
||||
logger.info("[%s] Processing through prefill: %s", req_id, p_url)
|
||||
|
||||
prefill_response = await process_prefill_stage(req_data, p_url, req_id)
|
||||
# for nixl connector to facilitate kv transfer...
|
||||
prefill_response_json = await prefill_response.json()
|
||||
kv_transfer_params = prefill_response_json.get("kv_transfer_params", {})
|
||||
if kv_transfer_params:
|
||||
req_data["kv_transfer_params"] = kv_transfer_params
|
||||
|
||||
return req_data
|
||||
else:
|
||||
return req_data
|
||||
|
||||
|
||||
async def process_prefill_stage(
|
||||
req_data: dict,
|
||||
p_url: str,
|
||||
req_id: str,
|
||||
) -> dict:
|
||||
"""Process request through Prefill stage and return kv_transfer_params"""
|
||||
logger.info("[%s] Sending prefill request to: %s", req_id, p_url)
|
||||
|
||||
prefill_request = req_data.copy()
|
||||
prefill_request["kv_transfer_params"] = {
|
||||
"do_remote_decode": True,
|
||||
"do_remote_prefill": False,
|
||||
"remote_engine_id": None,
|
||||
"remote_block_ids": None,
|
||||
"remote_host": None,
|
||||
"remote_port": None,
|
||||
}
|
||||
prefill_request["stream"] = False
|
||||
prefill_request["max_tokens"] = 1
|
||||
if "max_completion_tokens" in prefill_request:
|
||||
prefill_request["max_completion_tokens"] = 1
|
||||
if "stream_options" in prefill_request:
|
||||
del prefill_request["stream_options"]
|
||||
|
||||
headers = {"x-request-id": req_id}
|
||||
try:
|
||||
prefill_response = await prefill_session.post(
|
||||
f"{p_url}/v1/chat/completions", json=prefill_request, headers=headers
|
||||
)
|
||||
prefill_response.raise_for_status()
|
||||
|
||||
if prefill_response.status != 200:
|
||||
error_text = await prefill_response.text()
|
||||
logger.error(
|
||||
"[%s] Prefill request failed with status %d: %s",
|
||||
req_id,
|
||||
prefill_response.status,
|
||||
error_text,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=prefill_response.status,
|
||||
detail={"error": "Prefill request failed", "message": error_text},
|
||||
)
|
||||
logger.info("[%s] Prefill request completed successfully", req_id)
|
||||
|
||||
return prefill_response
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Prefill processing failed: %s", str(e))
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": "Prefill processing error", "message": str(e)},
|
||||
) from e
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Middleware for request/response logging
|
||||
###############################################################################
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def log_requests(request: Request, call_next):
|
||||
"""Middleware to log all incoming requests and responses"""
|
||||
req_id = request.headers.get("x-request-id", str(uuid.uuid4()))
|
||||
|
||||
# Log incoming request
|
||||
logger.info(
|
||||
">>> [%s] %s %s from %s",
|
||||
req_id,
|
||||
request.method,
|
||||
request.url.path,
|
||||
request.client.host if request.client else "unknown",
|
||||
)
|
||||
|
||||
try:
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Log response
|
||||
logger.info(
|
||||
"<<< [%s] %s %s completed with status %d",
|
||||
req_id,
|
||||
request.method,
|
||||
request.url.path,
|
||||
response.status_code,
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
# Log errors
|
||||
logger.exception(
|
||||
"!!! [%s] %s %s failed with error: %s",
|
||||
req_id,
|
||||
request.method,
|
||||
request.url.path,
|
||||
str(e),
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
###############################################################################
|
||||
# FastAPI lifecycle
|
||||
###############################################################################
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def on_startup() -> None:
|
||||
global encode_session, prefill_session, decode_session
|
||||
timeout = aiohttp.ClientTimeout(total=100_000)
|
||||
connector = aiohttp.TCPConnector(limit=0, force_close=False)
|
||||
encode_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
|
||||
if app.state.p_urls:
|
||||
# only setup if prefill instance(s) exist
|
||||
prefill_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
|
||||
decode_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def on_shutdown() -> None:
|
||||
global encode_session, prefill_session, decode_session
|
||||
if encode_session:
|
||||
await encode_session.close()
|
||||
if prefill_session:
|
||||
await prefill_session.close()
|
||||
if decode_session:
|
||||
await decode_session.close()
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Core forwarding
|
||||
###############################################################################
|
||||
|
||||
|
||||
async def forward_non_stream(
|
||||
req_data: dict, req_id: str, e_urls: list[str], p_url: str, d_url: str
|
||||
) -> dict:
|
||||
try:
|
||||
# Step 1: Process through Encoder instance (if has MM input)
|
||||
await fanout_encoder_primer(req_data, e_urls, req_id)
|
||||
|
||||
# Step 2: Process through Prefill instance
|
||||
req_data = await maybe_prefill(req_data, p_url, req_id)
|
||||
|
||||
# Step 3: Process through Decode instance
|
||||
logger.info("[%s] Forwarding to decode: %s", req_id, d_url)
|
||||
headers = {"x-request-id": req_id}
|
||||
|
||||
# Non-streaming response
|
||||
async with decode_session.post(
|
||||
f"{d_url}/v1/chat/completions", json=req_data, headers=headers
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
return await resp.json()
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("[%s] Error in forward_non_stream: %s", req_id, str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Proxy error: {str(e)}") from e
|
||||
|
||||
|
||||
async def forward_stream(
|
||||
req_data: dict, req_id: str, e_urls: list[str], p_url: str, d_url: str
|
||||
) -> AsyncIterator[str]:
|
||||
try:
|
||||
# Step 1: Process through Encoder instance (if has MM input)
|
||||
await fanout_encoder_primer(req_data, e_urls, req_id)
|
||||
|
||||
# Step 2: Process through Prefill instance
|
||||
req_data = await maybe_prefill(req_data, p_url, req_id)
|
||||
|
||||
# Step 3: Process through Decode instance
|
||||
logger.info("[%s] Starting streaming from decode: %s", req_id, d_url)
|
||||
headers = {"x-request-id": req_id}
|
||||
|
||||
# Streaming response
|
||||
async with decode_session.post(
|
||||
f"{d_url}/v1/chat/completions",
|
||||
json=req_data,
|
||||
headers=headers,
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.content.iter_chunked(1024):
|
||||
if chunk:
|
||||
yield chunk.decode("utf-8", errors="ignore")
|
||||
|
||||
logger.info("[%s] Streaming completed", req_id)
|
||||
|
||||
except HTTPException:
|
||||
logger.exception("[%s] HTTPException in forward_stream", req_id)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("[%s] Error in forward_stream: %s", req_id, str(e))
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Proxy streaming error: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Public routes
|
||||
###############################################################################
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: Request):
|
||||
try:
|
||||
req_data = await request.json()
|
||||
req_id = request.headers.get("x-request-id", str(uuid.uuid4()))
|
||||
|
||||
e_urls = app.state.e_urls # we want the full list for fan-out
|
||||
p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
|
||||
d_url = random.choice(app.state.d_urls)
|
||||
|
||||
is_streaming = req_data.get("stream", False)
|
||||
|
||||
if is_streaming:
|
||||
return StreamingResponse(
|
||||
forward_stream(req_data, req_id, e_urls, p_url, d_url),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
result = await forward_non_stream(req_data, req_id, e_urls, p_url, d_url)
|
||||
return JSONResponse(content=result)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error in chat_completions endpoint: %s", str(e))
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Request processing error: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def list_models():
|
||||
async with decode_session.get(f"{app.state.d_urls[0]}/v1/models") as resp:
|
||||
resp.raise_for_status()
|
||||
return await resp.json()
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
async def healthy(urls):
|
||||
if not urls:
|
||||
return "empty"
|
||||
for u in urls:
|
||||
try:
|
||||
async with encode_session.get(f"{u}/health") as resp:
|
||||
resp.raise_for_status()
|
||||
except Exception:
|
||||
return "unhealthy"
|
||||
return "healthy"
|
||||
|
||||
e_status, p_status, d_status = await asyncio.gather(
|
||||
healthy(app.state.e_urls), healthy(app.state.p_urls), healthy(app.state.d_urls)
|
||||
)
|
||||
|
||||
overall_healthy = all(
|
||||
status != "unhealthy" for status in (e_status, p_status, d_status)
|
||||
)
|
||||
|
||||
status_code = 200 if overall_healthy else 503
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"proxy": "healthy",
|
||||
"encode_cluster": e_status,
|
||||
"prefill_cluster": p_status,
|
||||
"decode_cluster": d_status,
|
||||
},
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Simple profiler fan-out (unchanged except for sessions)
|
||||
###############################################################################
|
||||
|
||||
|
||||
async def _post_if_available(
|
||||
session: aiohttp.ClientSession,
|
||||
url: str,
|
||||
payload: dict,
|
||||
headers: dict,
|
||||
) -> dict | None:
|
||||
"""
|
||||
POST `payload` to `url`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
• The decoded JSON body on success (2xx)
|
||||
• None if the endpoint does not exist (404)
|
||||
• Raises for anything else.
|
||||
"""
|
||||
try:
|
||||
resp = await session.post(url, json=payload, headers=headers)
|
||||
if resp.status == 404: # profiling disabled on that server
|
||||
logger.warning("Profiling endpoint missing on %s", url)
|
||||
return None
|
||||
resp.raise_for_status()
|
||||
return await resp.json(content_type=None)
|
||||
except aiohttp.ClientResponseError as exc:
|
||||
# Pass 404 through the branch above, re-raise everything else
|
||||
if exc.status == 404:
|
||||
logger.warning("Profiling endpoint missing on %s", url)
|
||||
return None
|
||||
raise
|
||||
except Exception:
|
||||
# Network errors etc.: propagate
|
||||
raise
|
||||
|
||||
|
||||
async def _profile_cmd(cmd: str, payload: dict, e_url: str, p_url: str, d_url: str):
|
||||
"""
|
||||
Fire & forget to both clusters, tolerate 404.
|
||||
"""
|
||||
headers = {"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY', '')}"}
|
||||
|
||||
encode_task = _post_if_available(
|
||||
encode_session, f"{e_url}/{cmd}_profile", payload, headers
|
||||
)
|
||||
prefill_task = (
|
||||
_post_if_available(prefill_session, f"{p_url}/{cmd}_profile", payload, headers)
|
||||
if p_url is not None
|
||||
else asyncio.sleep(0)
|
||||
)
|
||||
decode_task = _post_if_available(
|
||||
decode_session, f"{d_url}/{cmd}_profile", payload, headers
|
||||
)
|
||||
|
||||
encode_res, prefill_res, decode_res = await asyncio.gather(
|
||||
encode_task, prefill_task, decode_task
|
||||
)
|
||||
|
||||
# If *all* clusters said “I don’t have that route”, surface an error
|
||||
if encode_res is prefill_res is decode_res is None:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Profiling endpoints are disabled on all clusters",
|
||||
)
|
||||
|
||||
return {
|
||||
"encode": encode_res, # may be None
|
||||
"prefill": prefill_res, # may be None
|
||||
"decode": decode_res, # may be None
|
||||
}
|
||||
|
||||
|
||||
@app.post("/start_profile")
|
||||
async def start_profile(request: Request):
|
||||
body = await request.json()
|
||||
# TODO: handle multi urls properly
|
||||
e_url = random.choice(app.state.e_urls)
|
||||
p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
|
||||
d_url = random.choice(app.state.d_urls)
|
||||
return await _profile_cmd("start", body, e_url, p_url, d_url)
|
||||
|
||||
|
||||
@app.post("/stop_profile")
|
||||
async def stop_profile(request: Request):
|
||||
body = await request.json()
|
||||
# TODO: handle multi urls properly
|
||||
e_url = random.choice(app.state.e_urls)
|
||||
p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
|
||||
d_url = random.choice(app.state.d_urls)
|
||||
return await _profile_cmd("stop", body, e_url, p_url, d_url)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", default="0.0.0.0")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument(
|
||||
"--encode-servers-urls",
|
||||
required=True,
|
||||
help='Comma-separated encode URLs ("http://e1:8001,http://e2:8001")',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefill-servers-urls",
|
||||
required=True,
|
||||
help=(
|
||||
'Comma-separated prefill URLs ("http://p1:8003,http://p2:8004") ',
|
||||
'to enable E->P->D, set "disable" or "none" to enable E->PD',
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decode-servers-urls",
|
||||
required=True,
|
||||
help='Comma-separated decode URLs ("http://d1:8005,http://d2:8006")',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
app.state.e_urls = [
|
||||
u.strip() for u in args.encode_servers_urls.split(",") if u.strip()
|
||||
]
|
||||
app.state.d_urls = [
|
||||
u.strip() for u in args.decode_servers_urls.split(",") if u.strip()
|
||||
]
|
||||
# handle prefill instances
|
||||
if args.prefill_servers_urls.lower() in ("disable", "none", ""):
|
||||
app.state.p_urls = []
|
||||
logger.info(
|
||||
"Disaggregated prefill phase explicitly disabled by user. Running E + PD..."
|
||||
)
|
||||
else:
|
||||
app.state.p_urls = [
|
||||
u.strip() for u in args.prefill_servers_urls.split(",") if u.strip()
|
||||
]
|
||||
logger.info("Disaggregated prefill phase is enabled. Running E + P + D...")
|
||||
|
||||
logger.info("Proxy listening on %s:%s", args.host, args.port)
|
||||
logger.info("Encode servers: %s", app.state.e_urls)
|
||||
logger.info("Prefill instances %s", app.state.p_urls)
|
||||
logger.info("Decode servers: %s", app.state.d_urls)
|
||||
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level="info",
|
||||
loop="uvloop",
|
||||
access_log=True,
|
||||
)
|
||||
@ -37,7 +37,7 @@ from vllm.config import KVTransferConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
|
||||
|
||||
def setup_environment_variables(vllm_version: str):
|
||||
def setup_environment_variables():
|
||||
# LMCache-related environment variables
|
||||
# Use experimental features in LMCache
|
||||
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
|
||||
@ -47,12 +47,10 @@ def setup_environment_variables(vllm_version: str):
|
||||
os.environ["LMCACHE_LOCAL_CPU"] = "True"
|
||||
# Set local CPU memory limit to 5.0 GB
|
||||
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
|
||||
if vllm_version == "v0":
|
||||
os.environ["VLLM_USE_V1"] = "0"
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_llm_with_lmcache(lmcache_connector: str, model: str, vllm_version: str):
|
||||
def build_llm_with_lmcache(lmcache_connector: str, model: str):
|
||||
ktc = KVTransferConfig(
|
||||
kv_connector=lmcache_connector,
|
||||
kv_role="kv_both",
|
||||
@ -60,21 +58,12 @@ def build_llm_with_lmcache(lmcache_connector: str, model: str, vllm_version: str
|
||||
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
|
||||
# memory. Reduce the value if your GPU has less memory.
|
||||
# Note: LMCache supports chunked prefill (see vLLM#14505, LMCache#392).
|
||||
if vllm_version == "v0":
|
||||
llm_args = EngineArgs(
|
||||
model=model,
|
||||
kv_transfer_config=ktc,
|
||||
max_model_len=8000,
|
||||
gpu_memory_utilization=0.8,
|
||||
enable_chunked_prefill=True, # Only in v0
|
||||
)
|
||||
else:
|
||||
llm_args = EngineArgs(
|
||||
model=model,
|
||||
kv_transfer_config=ktc,
|
||||
max_model_len=8000,
|
||||
gpu_memory_utilization=0.8,
|
||||
)
|
||||
llm_args = EngineArgs(
|
||||
model=model,
|
||||
kv_transfer_config=ktc,
|
||||
max_model_len=8000,
|
||||
gpu_memory_utilization=0.8,
|
||||
)
|
||||
|
||||
llm = LLM(**asdict(llm_args))
|
||||
try:
|
||||
@ -116,18 +105,10 @@ def parse_args():
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
if args.version == "v0":
|
||||
lmcache_connector = "LMCacheConnector"
|
||||
model = "mistralai/Mistral-7B-Instruct-v0.2"
|
||||
else:
|
||||
lmcache_connector = "LMCacheConnectorV1"
|
||||
model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||||
|
||||
setup_environment_variables(args.version)
|
||||
|
||||
with build_llm_with_lmcache(lmcache_connector, model, args.version) as llm:
|
||||
lmcache_connector = "LMCacheConnectorV1"
|
||||
model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||||
setup_environment_variables()
|
||||
with build_llm_with_lmcache(lmcache_connector, model) as llm:
|
||||
# This example script runs two requests with a shared prefix.
|
||||
# Define the shared prompt and specific prompts
|
||||
shared_prompt = "Hello, how are you?" * 1000
|
||||
|
||||
@ -9,7 +9,6 @@ torch==2.9.0
|
||||
torchaudio==2.9.0
|
||||
# These must be updated alongside torch
|
||||
torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
|
||||
# Build from https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1
|
||||
xformers==0.0.33+5d4b92a5.d20251029; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
|
||||
xformers==0.0.33; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
|
||||
# FlashInfer should be updated together with the Dockerfile
|
||||
flashinfer-python==0.5.2
|
||||
|
||||
@ -9,7 +9,7 @@ torchaudio==2.9.0
|
||||
triton==3.5.0
|
||||
cmake>=3.26.1,<4
|
||||
packaging>=24.2
|
||||
setuptools>=77.0.3,<81.0.0
|
||||
setuptools>=77.0.3,<80.0.0
|
||||
setuptools-scm>=8
|
||||
wheel
|
||||
jinja2>=3.1.6
|
||||
|
||||
@ -10,7 +10,7 @@ peft
|
||||
pytest-asyncio
|
||||
tensorizer==2.10.1
|
||||
packaging>=24.2
|
||||
setuptools>=77.0.3,<81.0.0
|
||||
setuptools>=77.0.3,<80.0.0
|
||||
setuptools-scm>=8
|
||||
runai-model-streamer[s3,gcs]==0.15.0
|
||||
conch-triton-kernels==1.2.1
|
||||
|
||||
4
setup.py
4
setup.py
@ -208,6 +208,8 @@ class cmake_build_ext(build_ext):
|
||||
# Make sure we use the nvcc from CUDA_HOME
|
||||
if _is_cuda():
|
||||
cmake_args += [f"-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc"]
|
||||
elif _is_hip():
|
||||
cmake_args += [f"-DROCM_PATH={ROCM_HOME}"]
|
||||
|
||||
other_cmake_args = os.environ.get("CMAKE_ARGS")
|
||||
if other_cmake_args:
|
||||
@ -628,6 +630,7 @@ ext_modules = []
|
||||
|
||||
if _is_cuda() or _is_hip():
|
||||
ext_modules.append(CMakeExtension(name="vllm._moe_C"))
|
||||
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
||||
|
||||
if _is_hip():
|
||||
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
|
||||
@ -643,7 +646,6 @@ if _is_cuda():
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm._flashmla_extension_C", optional=True)
|
||||
)
|
||||
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
||||
|
||||
if _build_custom_ops():
|
||||
ext_modules.append(CMakeExtension(name="vllm._C"))
|
||||
|
||||
@ -8,12 +8,13 @@ import torch
|
||||
|
||||
from vllm import LLM, AsyncEngineArgs, AsyncLLMEngine, SamplingParams
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
|
||||
from ..utils import create_new_process_for_each_test
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
|
||||
def test_python_error():
|
||||
"""
|
||||
Test if Python error occurs when there's low-level
|
||||
@ -39,7 +40,7 @@ def test_python_error():
|
||||
allocator.wake_up()
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
|
||||
def test_basic_cumem():
|
||||
# some tensors from default memory pool
|
||||
shape = (1024, 1024)
|
||||
@ -72,7 +73,7 @@ def test_basic_cumem():
|
||||
assert torch.allclose(output, torch.ones_like(output) * 3)
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
|
||||
def test_cumem_with_cudagraph():
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
with allocator.use_memory_pool():
|
||||
@ -117,7 +118,7 @@ def test_cumem_with_cudagraph():
|
||||
assert torch.allclose(y, x + 1)
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
|
||||
@ -203,7 +203,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_cudagraph=True,
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
splitting_ops=["silly::attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
@ -281,7 +281,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_cudagraph=False,
|
||||
cudagraph_mode=CUDAGraphMode.NONE,
|
||||
splitting_ops=["silly::attention"],
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
)
|
||||
|
||||
@ -62,7 +62,6 @@ def _run_simple_model(
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_cudagraph=True,
|
||||
use_inductor=use_inductor,
|
||||
splitting_ops=splitting_ops,
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
|
||||
@ -449,7 +449,6 @@ def benchmark():
|
||||
if piecewise:
|
||||
compilation_config = CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly::attention"],
|
||||
cudagraph_capture_sizes=cudagraph_sizes,
|
||||
)
|
||||
|
||||
@ -127,7 +127,9 @@ def test_compile_correctness(
|
||||
CompilationMode.VLLM_COMPILE,
|
||||
]:
|
||||
for mode in [CompilationMode.NONE, comp_mode]:
|
||||
all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=inductor"])
|
||||
all_args.append(
|
||||
final_args + [f"-O.mode={mode.name}", "-O.backend=inductor"]
|
||||
)
|
||||
|
||||
# inductor will change the output, so we only compare if the output
|
||||
# is close, not exactly the same.
|
||||
@ -146,7 +148,7 @@ def test_compile_correctness(
|
||||
CompilationMode.DYNAMO_TRACE_ONCE,
|
||||
CompilationMode.VLLM_COMPILE,
|
||||
]:
|
||||
all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=eager"])
|
||||
all_args.append(final_args + [f"-O.mode={mode.name}", "-O.backend=eager"])
|
||||
all_envs.append({})
|
||||
all_envs.append({})
|
||||
|
||||
|
||||
@ -2,8 +2,10 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from contextlib import nullcontext
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
@ -11,7 +13,7 @@ from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.config.compilation import CompilationMode
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
|
||||
from vllm.utils.torch_utils import _is_torch_equal_or_newer
|
||||
|
||||
|
||||
def test_version():
|
||||
@ -23,14 +25,6 @@ def test_version():
|
||||
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
|
||||
|
||||
|
||||
def test_use_cudagraphs_dynamic():
|
||||
vllm_config = VllmConfig()
|
||||
# Default V1 configuration now starts without cudagraphs enabled; the
|
||||
# engine decides when to capture based on runtime settings instead of a
|
||||
# blanket default.
|
||||
assert vllm_config.compilation_config.use_cudagraph
|
||||
|
||||
|
||||
def test_copy_pass():
|
||||
vllm_config = VllmConfig()
|
||||
inductor_pass = FixFunctionalizationPass(vllm_config)
|
||||
@ -65,7 +59,7 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)
|
||||
|
||||
compilation_config = {
|
||||
"use_cudagraph": False, # speed things up a bit
|
||||
"cudagraph_mode": CUDAGraphMode.NONE, # speed things up a bit
|
||||
}
|
||||
with (
|
||||
compilation_counter.expect(
|
||||
@ -83,20 +77,31 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
|
||||
|
||||
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
|
||||
@pytest.mark.forked
|
||||
@pytest.mark.parametrize("enabled", [True, False])
|
||||
def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
|
||||
@pytest.mark.parametrize(
|
||||
"cudagraph_mode,num_cudagraph_captured",
|
||||
[
|
||||
(CUDAGraphMode.NONE, 0),
|
||||
(CUDAGraphMode.FULL_DECODE_ONLY, 1),
|
||||
(CUDAGraphMode.PIECEWISE, 13),
|
||||
(CUDAGraphMode.FULL_AND_PIECEWISE, 14),
|
||||
],
|
||||
)
|
||||
def test_use_cudagraphs(
|
||||
vllm_runner, monkeypatch, cudagraph_mode, num_cudagraph_captured
|
||||
):
|
||||
# Disable multiprocessing so that the counter is in the same process
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
|
||||
compilation_config = {
|
||||
"cudagraph_capture_sizes": [100],
|
||||
"use_cudagraph": enabled,
|
||||
"cudagraph_mode": cudagraph_mode,
|
||||
}
|
||||
num_gpu_runner_capture_triggers = 1 if cudagraph_mode != CUDAGraphMode.NONE else 0
|
||||
with (
|
||||
compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_gpu_runner_capture_triggers=1 if enabled else 0,
|
||||
num_cudagraph_captured=13 if enabled else 0,
|
||||
num_gpu_runner_capture_triggers=num_gpu_runner_capture_triggers,
|
||||
num_cudagraph_captured=num_cudagraph_captured,
|
||||
),
|
||||
# loading the model causes compilation (if enabled) to happen
|
||||
vllm_runner(
|
||||
@ -168,19 +173,18 @@ def test_splitting_ops_dynamic():
|
||||
assert not config.compilation_config.splitting_ops_contain_attention()
|
||||
|
||||
# When use_inductor_graph_partition=True
|
||||
if is_torch_equal_or_newer("2.9.0.dev"):
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_inductor_graph_partition=True,
|
||||
splitting_ops=["vllm::unified_attention"],
|
||||
)
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_inductor_graph_partition=True,
|
||||
splitting_ops=["vllm::unified_attention"],
|
||||
)
|
||||
# with inductor partition we use splitting_ops directly for
|
||||
# partition rules
|
||||
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
|
||||
)
|
||||
# with inductor partition we use splitting_ops directly for
|
||||
# partition rules
|
||||
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
|
||||
|
||||
# When attn_fusion pass enabled, splitting_ops now default to attention ops.
|
||||
# When attn_fusion pass enabled.
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
@ -189,29 +193,41 @@ def test_splitting_ops_dynamic():
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
)
|
||||
)
|
||||
# With the new simplified logic, attention fusion works with splitting_ops
|
||||
assert config.compilation_config.splitting_ops_contain_attention()
|
||||
# cudagraph mode remains PIECEWISE
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
|
||||
assert config.compilation_config.splitting_ops == []
|
||||
# cudagraph mode also fall back to FULL
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
|
||||
|
||||
# When both use_inductor_graph_partition and attn_fusion pass enabled.
|
||||
if is_torch_equal_or_newer("2.9.0.dev"):
|
||||
# splitting_ops can not contain attention ops when attn_fusion
|
||||
# pass enabled.
|
||||
with pytest.raises(ValidationError):
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_inductor_graph_partition=True,
|
||||
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
# work around for accessing all attntion ops
|
||||
splitting_ops=CompilationConfig()._attention_ops,
|
||||
)
|
||||
)
|
||||
# With inductor graph partition, attn_fusion and splitting_ops
|
||||
# work together. Default splitting_ops include attention ops.
|
||||
assert config.compilation_config.splitting_ops_contain_attention()
|
||||
# enable_attn_fusion is directly supported under
|
||||
# use_inductor_graph_partition=True, and cudagraph_mode
|
||||
# is unchanged.
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
|
||||
|
||||
# When both use_inductor_graph_partition and attn_fusion pass enabled.
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_inductor_graph_partition=True,
|
||||
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
)
|
||||
)
|
||||
# With inductor graph partition, attn_fusion and splitting_ops
|
||||
# work together. Default splitting_ops include attention ops.
|
||||
assert config.compilation_config.splitting_ops_contain_attention()
|
||||
# enable_attn_fusion is directly supported under
|
||||
# use_inductor_graph_partition=True, and cudagraph_mode
|
||||
# is unchanged.
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
|
||||
|
||||
|
||||
def test_should_split():
|
||||
@ -293,25 +309,36 @@ def test_should_split():
|
||||
"tp_size",
|
||||
"enable_sequence_parallelism",
|
||||
"max_num_batched_tokens",
|
||||
"use_cudagraph",
|
||||
"cudagraph_mode",
|
||||
"expected_max_size",
|
||||
),
|
||||
[
|
||||
(None, None, 1, False, 2048, True, 512),
|
||||
([1, 2, 4], 4, 1, False, 2048, True, 4),
|
||||
([1, 2, 4], 8, 1, False, 2048, True, RuntimeError),
|
||||
([1, 256], None, 1, False, 2048, 256),
|
||||
([], None, 1, False, 2048, False, 0),
|
||||
(None, 0, 1, False, 2048, False, 0),
|
||||
(None, None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
|
||||
([1, 2, 4], 4, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
|
||||
(
|
||||
[1, 2, 4],
|
||||
8,
|
||||
1,
|
||||
False,
|
||||
2048,
|
||||
CUDAGraphMode.FULL_AND_PIECEWISE,
|
||||
ValidationError,
|
||||
),
|
||||
([1, 256], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
|
||||
([], None, 1, False, 2048, CUDAGraphMode.NONE, 0),
|
||||
(None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0),
|
||||
# truncated to nearest multiple of 8 or 16
|
||||
(None, 257, 1, False, 2048, True, 256),
|
||||
([1, 2, 4, 15], None, 1, False, 2048, True, 15), # max from list
|
||||
([1, 2, 4, 15], None, 2, True, 2048, True, 4), # filtered out 15 due to SP
|
||||
([1, 2, 4, 15], None, 1, False, 8, True, 4), # limited by the max_tokens
|
||||
(None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
|
||||
# max from list
|
||||
([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
|
||||
# filtered out 15 due to SP
|
||||
([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
|
||||
# limited by the max_tokens
|
||||
([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
|
||||
# the list should contain at least 1 element when use cudagraph
|
||||
([], None, 1, False, 2048, True, RuntimeError),
|
||||
([], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
|
||||
# the max capturing size should be >= 1 when use cudagraph
|
||||
(None, 0, 1, False, 2048, True, RuntimeError),
|
||||
(None, 0, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
|
||||
],
|
||||
)
|
||||
def test_cudagraph_sizes_post_init(
|
||||
@ -320,15 +347,17 @@ def test_cudagraph_sizes_post_init(
|
||||
tp_size,
|
||||
enable_sequence_parallelism,
|
||||
max_num_batched_tokens,
|
||||
use_cudagraph,
|
||||
cudagraph_mode,
|
||||
expected_max_size,
|
||||
):
|
||||
ctx = nullcontext()
|
||||
if isinstance(expected_max_size, Exception):
|
||||
if expected_max_size == ValidationError:
|
||||
ctx = pytest.raises(expected_max_size)
|
||||
|
||||
cudagraph_mode = CUDAGraphMode.PIECEWISE if use_cudagraph else CUDAGraphMode.NONE
|
||||
with ctx:
|
||||
with (
|
||||
ctx,
|
||||
patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size),
|
||||
):
|
||||
compilation_config = CompilationConfig(
|
||||
cudagraph_capture_sizes=cudagraph_capture_sizes,
|
||||
max_cudagraph_capture_size=max_cudagraph_capture_size,
|
||||
@ -342,11 +371,13 @@ def test_cudagraph_sizes_post_init(
|
||||
engine_args = EngineArgs(
|
||||
model="facebook/opt-125m",
|
||||
tensor_parallel_size=tp_size,
|
||||
max_num_seqs=min(max_num_batched_tokens, 128),
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
|
||||
assert (
|
||||
vllm_config.compilation_config.max_cudagraph_capture_size == expected_max_size
|
||||
)
|
||||
assert (
|
||||
vllm_config.compilation_config.max_cudagraph_capture_size
|
||||
== expected_max_size
|
||||
)
|
||||
|
||||
@ -80,7 +80,6 @@ def test_ignore_torch_compile_decorator(use_inductor_graph_partition, monkeypatc
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly::attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
@ -215,7 +214,6 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly::attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
@ -257,7 +255,6 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly::attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
|
||||
@ -10,6 +10,7 @@ import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
@ -184,13 +185,24 @@ def test_custom_compile_config(
|
||||
[CompilationMode.NONE, CompilationMode.VLLM_COMPILE],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
"model, backend",
|
||||
[
|
||||
"Qwen/Qwen2-0.5B", # Standard attention model
|
||||
"deepseek-ai/DeepSeek-V2-Lite", # MLA (Multi-head Latent Attention) model
|
||||
("Qwen/Qwen2-0.5B", None), # Standard attention model
|
||||
(
|
||||
"deepseek-ai/DeepSeek-V2-Lite",
|
||||
AttentionBackendEnum.FLASHINFER_MLA,
|
||||
), # MLA (Multi-head Latent Attention) model
|
||||
],
|
||||
)
|
||||
def test_fp8_kv_scale_compile(compilation_mode: int, model: str):
|
||||
def test_fp8_kv_scale_compile(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
compilation_mode: int,
|
||||
model: str,
|
||||
backend: AttentionBackendEnum | None,
|
||||
):
|
||||
if backend:
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||
|
||||
model_kwargs = {
|
||||
"quantization": "fp8",
|
||||
"kv_cache_dtype": "fp8_e4m3",
|
||||
|
||||
@ -10,7 +10,7 @@ from tests.utils import flat_product
|
||||
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
@ -105,7 +105,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
|
||||
# TODO(luka) use get_kv_cache_stride_order
|
||||
# Create dummy KV cache for the selected backend
|
||||
if backend == _Backend.ROCM_ATTN:
|
||||
if backend == AttentionBackendEnum.ROCM_ATTN:
|
||||
# k/v as 1st dimention
|
||||
# HND: [num_blocks, num_kv_heads, block_size, head_size]
|
||||
kv_cache = torch.zeros(
|
||||
@ -117,7 +117,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
|
||||
elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
|
||||
# k/v as 1st dimention
|
||||
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
|
||||
kv_cache = torch.zeros(
|
||||
@ -129,7 +129,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
elif backend == _Backend.TRITON_ATTN:
|
||||
elif backend == AttentionBackendEnum.TRITON_ATTN:
|
||||
# k/v as 2nd dimention
|
||||
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
|
||||
kv_cache = torch.zeros(
|
||||
@ -141,7 +141,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
elif backend == _Backend.FLASHINFER:
|
||||
elif backend == AttentionBackendEnum.FLASHINFER:
|
||||
kv_cache = torch.zeros(
|
||||
num_blocks,
|
||||
2,
|
||||
@ -242,8 +242,8 @@ MODELS_FP8: list[tuple[str, type]] = []
|
||||
MODELS_FP4: list[tuple[str, type]] = []
|
||||
HEADS: list[tuple[int, int]] = []
|
||||
SPLIT_ATTENTION: list[bool] = []
|
||||
BACKENDS_FP8: list[_Backend] = []
|
||||
BACKENDS_FP4: list[_Backend] = []
|
||||
BACKENDS_FP8: list[AttentionBackendEnum] = []
|
||||
BACKENDS_FP4: list[AttentionBackendEnum] = []
|
||||
|
||||
if current_platform.is_cuda():
|
||||
HEADS = [(64, 8), (40, 8)]
|
||||
@ -259,8 +259,8 @@ if current_platform.is_cuda():
|
||||
TestAttentionNvfp4QuantPatternModel,
|
||||
)
|
||||
]
|
||||
BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER]
|
||||
BACKENDS_FP4 = [_Backend.FLASHINFER]
|
||||
BACKENDS_FP8 = [AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.FLASHINFER]
|
||||
BACKENDS_FP4 = [AttentionBackendEnum.FLASHINFER]
|
||||
|
||||
elif current_platform.is_rocm():
|
||||
HEADS = [(32, 8), (40, 8)]
|
||||
@ -268,9 +268,9 @@ elif current_platform.is_rocm():
|
||||
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
|
||||
]
|
||||
BACKENDS = [
|
||||
_Backend.ROCM_AITER_UNIFIED_ATTN,
|
||||
_Backend.ROCM_ATTN,
|
||||
_Backend.TRITON_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
|
||||
AttentionBackendEnum.ROCM_ATTN,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
]
|
||||
|
||||
|
||||
@ -300,11 +300,11 @@ def test_attention_quant_pattern(
|
||||
custom_ops: str,
|
||||
model_name: str,
|
||||
model_class: type[AttentionQuantPatternModel],
|
||||
backend: _Backend,
|
||||
backend: AttentionBackendEnum,
|
||||
dist_init,
|
||||
):
|
||||
"""Test AttentionStaticQuantPattern fusion pass"""
|
||||
if backend == _Backend.FLASHINFER and (
|
||||
if backend == AttentionBackendEnum.FLASHINFER and (
|
||||
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
|
||||
):
|
||||
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
|
||||
@ -312,6 +312,7 @@ def test_attention_quant_pattern(
|
||||
custom_ops_list = custom_ops.split(",") if custom_ops else []
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(42)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
@ -400,7 +401,7 @@ def test_attention_quant_pattern(
|
||||
|
||||
result_fused_1 = model_compiled(q, k, v)
|
||||
|
||||
if backend == _Backend.FLASHINFER:
|
||||
if backend == AttentionBackendEnum.FLASHINFER:
|
||||
# With the Flashinfer backend after the 1st round of the forward
|
||||
# pass, output quant scale should be loaded into the attn layer's
|
||||
# _o_scale_float, the 2nd round should reuse the loaded
|
||||
|
||||
@ -11,7 +11,7 @@ from typing import Any, NamedTuple
|
||||
import pytest
|
||||
import regex as re
|
||||
|
||||
from tests.v1.attention.utils import _Backend
|
||||
from tests.v1.attention.utils import AttentionBackendEnum
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
|
||||
from vllm.platforms import current_platform
|
||||
@ -24,7 +24,7 @@ from ..utils import flat_product, multi_gpu_test
|
||||
class ModelBackendTestCase(NamedTuple):
|
||||
model_name: str
|
||||
model_kwargs: dict[str, Any]
|
||||
backend: _Backend
|
||||
backend: AttentionBackendEnum
|
||||
attention_fusions: int
|
||||
allreduce_fusions: int | None = None
|
||||
|
||||
@ -39,14 +39,14 @@ if current_platform.is_cuda():
|
||||
# Use smaller model for L40s in CI
|
||||
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=_Backend.TRITON_ATTN,
|
||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||
attention_fusions=32,
|
||||
allreduce_fusions=65,
|
||||
),
|
||||
ModelBackendTestCase(
|
||||
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
|
||||
backend=_Backend.FLASHINFER,
|
||||
backend=AttentionBackendEnum.FLASHINFER,
|
||||
attention_fusions=48,
|
||||
allreduce_fusions=96,
|
||||
),
|
||||
@ -56,7 +56,7 @@ if current_platform.is_cuda():
|
||||
ModelBackendTestCase(
|
||||
model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
|
||||
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
|
||||
backend=_Backend.FLASHINFER,
|
||||
backend=AttentionBackendEnum.FLASHINFER,
|
||||
attention_fusions=32,
|
||||
allreduce_fusions=65,
|
||||
),
|
||||
@ -67,14 +67,14 @@ if current_platform.is_cuda():
|
||||
ModelBackendTestCase(
|
||||
model_name="meta-llama/Llama-3.1-8B-Instruct",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=_Backend.TRITON_ATTN,
|
||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||
attention_fusions=0,
|
||||
allreduce_fusions=65,
|
||||
),
|
||||
ModelBackendTestCase(
|
||||
model_name="Qwen/Qwen3-30B-A3B",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=_Backend.TRITON_ATTN,
|
||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||
attention_fusions=0,
|
||||
allreduce_fusions=97,
|
||||
),
|
||||
@ -85,19 +85,19 @@ elif current_platform.is_rocm():
|
||||
ModelBackendTestCase(
|
||||
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=_Backend.TRITON_ATTN,
|
||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||
attention_fusions=32,
|
||||
),
|
||||
ModelBackendTestCase(
|
||||
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=_Backend.ROCM_ATTN,
|
||||
backend=AttentionBackendEnum.ROCM_ATTN,
|
||||
attention_fusions=32,
|
||||
),
|
||||
ModelBackendTestCase(
|
||||
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=_Backend.ROCM_AITER_UNIFIED_ATTN,
|
||||
backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
|
||||
attention_fusions=32,
|
||||
),
|
||||
]
|
||||
@ -117,7 +117,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
|
||||
def test_attn_quant(
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
backend: _Backend,
|
||||
backend: AttentionBackendEnum,
|
||||
attention_fusions: int,
|
||||
allreduce_fusions: int,
|
||||
custom_ops: str,
|
||||
@ -125,7 +125,7 @@ def test_attn_quant(
|
||||
caplog_mp_spawn,
|
||||
monkeypatch,
|
||||
):
|
||||
if backend == _Backend.FLASHINFER and (
|
||||
if backend == AttentionBackendEnum.FLASHINFER and (
|
||||
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
|
||||
):
|
||||
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
|
||||
@ -208,7 +208,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
|
||||
def test_tp2_attn_quant_allreduce_rmsnorm(
|
||||
model_name: str,
|
||||
model_kwargs: dict,
|
||||
backend: _Backend,
|
||||
backend: AttentionBackendEnum,
|
||||
attention_fusions: int,
|
||||
allreduce_fusions: int,
|
||||
custom_ops: str,
|
||||
|
||||
195
tests/compile/test_qk_norm_rope_fusion.py
Normal file
195
tests/compile/test_qk_norm_rope_fusion.py
Normal file
@ -0,0 +1,195 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.compile.backend import TestBackend
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.compilation.matcher_utils import FLASHINFER_ROTARY_OP, RMS_OP, ROTARY_OP
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.qk_norm_rope_fusion import (
|
||||
FUSED_QK_ROPE_OP,
|
||||
QKNormRoPEFusionPass,
|
||||
)
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
RSQRT_OP = torch.ops.aten.rsqrt.default
|
||||
INDEX_SELECT_OP = torch.ops.aten.index.Tensor
|
||||
|
||||
|
||||
class QKNormRoPETestModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
eps: float,
|
||||
is_neox: bool,
|
||||
vllm_config: VllmConfig,
|
||||
dtype: torch.dtype,
|
||||
prefix: str = "model.layers.0.self_attn.attn",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_dim = head_dim
|
||||
self.q_size = num_heads * head_dim
|
||||
self.kv_size = num_kv_heads * head_dim
|
||||
self.rotary_dim = head_dim
|
||||
self.eps = eps
|
||||
self.dtype = dtype
|
||||
|
||||
# Register layer metadata for the fusion pass via Attention.
|
||||
self.attn = Attention(
|
||||
num_heads=self.num_heads,
|
||||
head_size=self.head_dim,
|
||||
scale=1.0 / self.head_dim**0.5,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=vllm_config.cache_config,
|
||||
prefix=prefix,
|
||||
attn_type=AttentionType.DECODER,
|
||||
)
|
||||
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=self.eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=self.eps)
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
self.head_dim,
|
||||
rotary_dim=self.rotary_dim,
|
||||
max_position_embeddings=4096,
|
||||
base=10000,
|
||||
is_neox_style=is_neox,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.enable_rms_norm_custom_op = self.q_norm.enabled()
|
||||
self.enable_rope_custom_op = self.rotary_emb.enabled()
|
||||
|
||||
def forward(self, qkv: torch.Tensor, positions: torch.Tensor):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
|
||||
q_by_head = self.q_norm(q_by_head)
|
||||
q = q_by_head.view(q.shape)
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
|
||||
k_by_head = self.k_norm(k_by_head)
|
||||
k = k_by_head.view(k.shape)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
return q, k, v
|
||||
|
||||
def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
|
||||
ops = []
|
||||
if self.enable_rms_norm_custom_op:
|
||||
ops.append(RMS_OP)
|
||||
else:
|
||||
ops.append(RSQRT_OP)
|
||||
|
||||
if self.enable_rope_custom_op:
|
||||
if self.rotary_emb.use_flashinfer:
|
||||
ops.append(FLASHINFER_ROTARY_OP)
|
||||
else:
|
||||
ops.append(ROTARY_OP)
|
||||
else:
|
||||
ops.append(INDEX_SELECT_OP)
|
||||
return ops
|
||||
|
||||
def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
|
||||
return [FUSED_QK_ROPE_OP]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||
@pytest.mark.parametrize("is_neox", [True, False])
|
||||
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
||||
@pytest.mark.parametrize("enable_rope_custom_op", [True])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="Only test on cuda and rocm platform",
|
||||
)
|
||||
def test_qk_norm_rope_fusion(
|
||||
eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype
|
||||
):
|
||||
if not hasattr(torch.ops._C, "fused_qk_norm_rope"):
|
||||
pytest.skip("fused_qk_norm_rope custom op not available")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(0)
|
||||
|
||||
custom_ops: list[str] = []
|
||||
if enable_rms_norm_custom_op:
|
||||
custom_ops.append("+rms_norm")
|
||||
if enable_rope_custom_op:
|
||||
custom_ops.append("+rotary_embedding")
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(dtype=dtype),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=custom_ops,
|
||||
pass_config=PassConfig(
|
||||
enable_qk_norm_rope_fusion=True,
|
||||
enable_noop=True,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
num_heads, num_kv_heads, head_dim = 16, 4, 128
|
||||
T = 5
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = QKNormRoPETestModel(
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
eps=eps,
|
||||
is_neox=is_neox,
|
||||
vllm_config=vllm_config,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
fusion_pass = QKNormRoPEFusionPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
||||
backend_baseline = TestBackend(noop_pass, cleanup_pass)
|
||||
|
||||
qkv = torch.randn(T, model.q_size + 2 * model.kv_size)
|
||||
pos = torch.arange(T, dtype=torch.long, device=qkv.device)
|
||||
qkv_unfused = qkv.clone()
|
||||
pos_unfused = pos.clone()
|
||||
|
||||
torch._dynamo.mark_dynamic(qkv, 0)
|
||||
torch._dynamo.mark_dynamic(pos, 0)
|
||||
model_fused = torch.compile(model, backend=backend)
|
||||
q_fused, k_fused, v_fused = model_fused(qkv, pos)
|
||||
|
||||
torch._dynamo.mark_dynamic(qkv_unfused, 0)
|
||||
torch._dynamo.mark_dynamic(pos_unfused, 0)
|
||||
model_unfused = torch.compile(model, backend=backend_baseline)
|
||||
q_unfused, k_unfused, v_unfused = model_unfused(qkv_unfused, pos_unfused)
|
||||
|
||||
if dtype == torch.float16:
|
||||
ATOL, RTOL = (2e-3, 2e-3)
|
||||
else:
|
||||
ATOL, RTOL = (1e-2, 1e-2)
|
||||
|
||||
torch.testing.assert_close(q_unfused, q_fused, atol=ATOL, rtol=RTOL)
|
||||
torch.testing.assert_close(k_unfused, k_fused, atol=ATOL, rtol=RTOL)
|
||||
torch.testing.assert_close(v_unfused, v_fused, atol=ATOL, rtol=RTOL)
|
||||
|
||||
assert fusion_pass.matched_count == 1
|
||||
|
||||
backend.check_before_ops(model.ops_in_model_before())
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
@ -3,13 +3,13 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config.multimodal import MultiModalConfig
|
||||
|
||||
|
||||
def test_mm_encoder_attn_backend_str_conversion():
|
||||
config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
|
||||
assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN
|
||||
assert config.mm_encoder_attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
|
||||
def test_mm_encoder_attn_backend_invalid():
|
||||
@ -20,6 +20,6 @@ def test_mm_encoder_attn_backend_invalid():
|
||||
def test_mm_encoder_attn_backend_hash_updates():
|
||||
base_hash = MultiModalConfig().compute_hash()
|
||||
overridden_hash = MultiModalConfig(
|
||||
mm_encoder_attn_backend=_Backend.FLASH_ATTN
|
||||
mm_encoder_attn_backend=AttentionBackendEnum.FLASH_ATTN
|
||||
).compute_hash()
|
||||
assert base_hash != overridden_hash
|
||||
|
||||
@ -22,9 +22,6 @@ def monkeypatch_module():
|
||||
|
||||
@pytest.fixture(scope="module", params=[True])
|
||||
def server(request, monkeypatch_module):
|
||||
use_v1 = request.param
|
||||
monkeypatch_module.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
|
||||
|
||||
args = [
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
|
||||
@ -34,6 +34,9 @@ class MockConversationContext(ConversationContext):
|
||||
def append_output(self, output) -> None:
|
||||
pass
|
||||
|
||||
def append_tool_output(self, output) -> None:
|
||||
pass
|
||||
|
||||
async def call_tool(self):
|
||||
return []
|
||||
|
||||
|
||||
30
tests/entrypoints/test_responses_utils.py
Normal file
30
tests/entrypoints/test_responses_utils.py
Normal file
@ -0,0 +1,30 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.entrypoints.responses_utils import (
|
||||
convert_tool_responses_to_completions_format,
|
||||
)
|
||||
|
||||
|
||||
class TestResponsesUtils:
|
||||
"""Tests for convert_tool_responses_to_completions_format function."""
|
||||
|
||||
def test_convert_tool_responses_to_completions_format(self):
|
||||
"""Test basic conversion of a flat tool schema to nested format."""
|
||||
input_tool = {
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string"},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location", "unit"],
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_tool_responses_to_completions_format(input_tool)
|
||||
|
||||
assert result == {"type": "function", "function": input_tool}
|
||||
@ -35,7 +35,7 @@ DEVICE_MLA_BACKENDS = {
|
||||
DEVICE_REGULAR_ATTN_BACKENDS = {
|
||||
"cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
|
||||
"hip": ["ROCM_ATTN"],
|
||||
"cpu": ["TORCH_SDPA"],
|
||||
"cpu": ["CPU_ATTN"],
|
||||
}
|
||||
|
||||
DEVICE_MLA_BLOCK_SIZES = {
|
||||
@ -86,7 +86,7 @@ def test_env(
|
||||
if device == "cpu":
|
||||
with patch("vllm.platforms.current_platform", CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16, None, block_size)
|
||||
assert backend.get_name() == "TORCH_SDPA"
|
||||
assert backend.get_name() == "CPU_ATTN"
|
||||
|
||||
elif device == "hip":
|
||||
with patch("vllm.platforms.current_platform", RocmPlatform()):
|
||||
@ -120,12 +120,13 @@ def test_env(
|
||||
|
||||
elif device == "cuda":
|
||||
with patch("vllm.platforms.current_platform", CudaPlatform()):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if use_mla:
|
||||
# CUDA MLA backend logic:
|
||||
# - CUTLASS_MLA: only supported with block_size == 128
|
||||
# and Blackwell GPUs (SM 10.0), V1 only
|
||||
# and Blackwell GPUs (SM 10.x), V1 only
|
||||
# - FLASHINFER_MLA: only supported on Blackwell GPUs
|
||||
# (SM 10.0+), V1 only
|
||||
# (SM 10.x), V1 only
|
||||
# - FLASHMLA: only supported with block_size == 64
|
||||
# - FLASH_ATTN_MLA: V1 only
|
||||
# - TRITON_MLA: fallback for other cases
|
||||
@ -134,58 +135,72 @@ def test_env(
|
||||
if block_size != 128:
|
||||
# CUTLASS_MLA only supports block_size == 128
|
||||
pytest.skip("CUTLASS_MLA only supports block_size 128")
|
||||
else:
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "CUTLASS_MLA"
|
||||
assert backend.get_name() == expected
|
||||
if capability[0] != 10:
|
||||
pytest.skip("CUTLASS MLA is not supported on this platform")
|
||||
backend = get_attn_backend(
|
||||
576, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "CUTLASS_MLA"
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASHINFER_MLA":
|
||||
if capability[0] != 10:
|
||||
pytest.skip(
|
||||
"FlashInfer MLA is not supported on this platform"
|
||||
)
|
||||
if block_size not in [32, 64]:
|
||||
# FlashInfer MLA only supports block_size 32 or 64
|
||||
pytest.skip(
|
||||
"FlashInfer MLA only supports block_size 32 or 64"
|
||||
)
|
||||
else:
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "FLASHINFER_MLA"
|
||||
assert backend.get_name() == expected
|
||||
backend = get_attn_backend(
|
||||
576, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "FLASHINFER_MLA"
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASHMLA":
|
||||
if block_size != 64:
|
||||
# FlashMLA only supports block_size == 64
|
||||
pytest.skip("FlashMLA only supports block_size 64")
|
||||
else:
|
||||
from vllm.v1.attention.backends.mla.flashmla import (
|
||||
is_flashmla_dense_supported,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.flashmla import (
|
||||
is_flashmla_dense_supported,
|
||||
)
|
||||
|
||||
is_supported, _ = is_flashmla_dense_supported()
|
||||
if not is_supported:
|
||||
pytest.skip("FlashMLA not supported on this platform")
|
||||
else:
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = name
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASH_ATTN_MLA":
|
||||
is_supported, _ = is_flashmla_dense_supported()
|
||||
if not is_supported:
|
||||
pytest.skip("FlashMLA not supported on this platform")
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
576,
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
use_mla=use_mla,
|
||||
)
|
||||
expected = name
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASH_ATTN_MLA":
|
||||
from vllm.attention.utils.fa_utils import (
|
||||
flash_attn_supports_mla,
|
||||
)
|
||||
|
||||
if not flash_attn_supports_mla():
|
||||
pytest.skip(
|
||||
"FlashAttention MLA not supported on this platform"
|
||||
)
|
||||
backend = get_attn_backend(
|
||||
576, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "FLASH_ATTN_MLA"
|
||||
assert backend.get_name() == expected
|
||||
else:
|
||||
# TRITON_MLA or other fallback
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
576, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "TRITON_MLA"
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASHINFER":
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
64, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "FLASHINFER"
|
||||
assert backend.get_name() == expected
|
||||
@ -209,7 +224,7 @@ def test_fp32_fallback(device: str):
|
||||
if device == "cpu":
|
||||
with patch("vllm.platforms.current_platform", CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||
assert backend.get_name() == "TORCH_SDPA"
|
||||
assert backend.get_name() == "CPU_ATTN"
|
||||
|
||||
elif device == "cuda":
|
||||
with patch("vllm.platforms.current_platform", CudaPlatform()):
|
||||
|
||||
575
tests/kernels/attention/test_cpu_attn.py
Normal file
575
tests/kernels/attention/test_cpu_attn.py
Normal file
@ -0,0 +1,575 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_cpu():
|
||||
pytest.skip("skipping CPU-only tests", allow_module_level=True)
|
||||
|
||||
from vllm._custom_ops import (
|
||||
cpu_attention_with_kv_cache,
|
||||
cpu_attn_get_scheduler_metadata,
|
||||
cpu_attn_reshape_and_cache,
|
||||
)
|
||||
|
||||
NUM_HEADS = [
|
||||
(4, 4),
|
||||
(8, 2),
|
||||
(9, 3),
|
||||
]
|
||||
HEAD_SIZES = [96, 128]
|
||||
QTYPES = [torch.bfloat16, torch.half, torch.float32]
|
||||
SLIDING_WINDOWS = [None, 256]
|
||||
NUM_BLOCKS = [
|
||||
1024,
|
||||
]
|
||||
SEQ_LENS = [ # (q_len, kv_len)
|
||||
[(1, 213), (1, 1), (1, 312), (1, 7), (1, 7812)], # decode batch
|
||||
[(2345, 2345), (5, 5), (3, 16), (134, 5131)], # prefill batch
|
||||
[(992, 2456), (1, 1234), (98, 1145), (1, 4162), (2345, 2345)], # mixed batch
|
||||
]
|
||||
|
||||
|
||||
# rand number generation takes too much time, cache rand tensors
|
||||
@functools.lru_cache(maxsize=128, typed=False)
|
||||
def tensor_cache(
|
||||
elem_num: int,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
tensor = torch.randn(elem_num, dtype=dtype)
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
|
||||
base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != total_num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
num_remaining_heads = min(
|
||||
closest_power_of_2, total_num_heads - closest_power_of_2
|
||||
)
|
||||
extra_powers = torch.arange(
|
||||
start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32
|
||||
)
|
||||
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
return slopes.float()
|
||||
|
||||
|
||||
def ref_paged_attn(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
query_lens: list[int],
|
||||
kv_lens: list[int],
|
||||
block_tables: torch.Tensor,
|
||||
scale: float,
|
||||
sliding_window: int | None = None,
|
||||
soft_cap: float | None = None,
|
||||
alibi_slopes: torch.Tensor | None = None,
|
||||
s_aux: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
num_seqs = len(query_lens)
|
||||
block_tables = block_tables.cpu().numpy()
|
||||
_, block_size, num_kv_heads, head_size = key_cache.shape
|
||||
dtype = query.dtype
|
||||
|
||||
outputs: list[torch.Tensor] = []
|
||||
start_idx = 0
|
||||
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = alibi_slopes[:, None, None]
|
||||
|
||||
if s_aux is not None:
|
||||
s_aux = s_aux.float()
|
||||
s_aux = s_aux[:, None, None]
|
||||
|
||||
for i in range(num_seqs):
|
||||
query_len = query_lens[i]
|
||||
kv_len = kv_lens[i]
|
||||
q = query[start_idx : start_idx + query_len].float()
|
||||
q *= scale
|
||||
|
||||
num_kv_blocks = (kv_len + block_size - 1) // block_size
|
||||
block_indices = block_tables[i, :num_kv_blocks]
|
||||
|
||||
k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
|
||||
k = k[:kv_len].float()
|
||||
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
|
||||
v = v[:kv_len].float()
|
||||
|
||||
if q.shape[1] != k.shape[1]:
|
||||
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
|
||||
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
|
||||
attn = torch.einsum("qhd,khd->hqk", q, k).float()
|
||||
empty_mask = torch.ones(query_len, kv_len)
|
||||
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
|
||||
|
||||
if sliding_window is not None:
|
||||
sliding_window_mask = (
|
||||
torch.triu(
|
||||
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
|
||||
)
|
||||
.bool()
|
||||
.logical_not()
|
||||
)
|
||||
mask |= sliding_window_mask
|
||||
|
||||
if soft_cap is not None:
|
||||
attn = soft_cap * torch.tanh(attn / soft_cap)
|
||||
|
||||
if alibi_slopes is not None:
|
||||
q_start_pos = kv_len - query_len
|
||||
q_pos = q_start_pos + torch.arange(0, query_len)[None, :, None]
|
||||
kv_pos = torch.arange(0, kv_len)[None, None, :]
|
||||
dist = q_pos - kv_pos
|
||||
alibi_bias = -alibi_slopes * dist
|
||||
attn += alibi_bias
|
||||
|
||||
attn.masked_fill_(mask, float("-inf"))
|
||||
|
||||
if s_aux is not None:
|
||||
s_aux_ext = s_aux.repeat(1, query_len, 1)
|
||||
attn = torch.cat((s_aux_ext, attn), dim=-1)
|
||||
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
|
||||
if s_aux is not None:
|
||||
attn = attn[:, :, 1:]
|
||||
|
||||
out = torch.einsum("hqk,khd->qhd", attn, v).to(dtype=dtype)
|
||||
|
||||
outputs.append(out)
|
||||
start_idx += query_len
|
||||
|
||||
return torch.cat(outputs, dim=0)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def varlen_with_paged_kv(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [x[0] for x in seq_lens]
|
||||
kv_lens = [x[1] for x in seq_lens]
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_kv_len = max(kv_lens)
|
||||
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
|
||||
scale = head_size**-0.5
|
||||
token_num = sum(query_lens)
|
||||
|
||||
# for n heads the set of slopes is the geometric sequence that starts
|
||||
# 2^(-8/n)
|
||||
alibi_slopes = _get_alibi_slopes(num_query_heads) if use_alibi else None
|
||||
|
||||
s_aux = (
|
||||
15 * torch.rand((num_query_heads,), dtype=torch.bfloat16) if use_sink else None
|
||||
)
|
||||
|
||||
query = tensor_cache(
|
||||
elem_num=token_num * num_query_heads * head_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
query = query.view(
|
||||
token_num,
|
||||
num_query_heads,
|
||||
head_size,
|
||||
)
|
||||
|
||||
key_value = tensor_cache(
|
||||
elem_num=2 * num_blocks * num_kv_heads * block_size * head_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
key_value = key_value.view(
|
||||
2,
|
||||
num_blocks,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
)
|
||||
key_cache, value_cache = key_value.unbind(0)
|
||||
|
||||
# KV cache for CPU attention
|
||||
packed_key_cache = torch.empty(
|
||||
num_blocks, num_kv_heads, block_size, head_size, dtype=dtype
|
||||
)
|
||||
packed_value_cache = torch.empty_like(packed_key_cache)
|
||||
|
||||
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
|
||||
dim=0, dtype=torch.int32
|
||||
)
|
||||
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
# use reshape_and_cache to pack key_cache and value_cache
|
||||
slot_mapping = torch.arange(0, num_blocks * block_size, dtype=torch.int64)
|
||||
cpu_attn_reshape_and_cache(
|
||||
key=key_cache.view(-1, num_kv_heads, head_size),
|
||||
value=value_cache.view(-1, num_kv_heads, head_size),
|
||||
key_cache=packed_key_cache,
|
||||
value_cache=packed_value_cache,
|
||||
slot_mapping=slot_mapping,
|
||||
isa=isa,
|
||||
)
|
||||
|
||||
metadata = cpu_attn_get_scheduler_metadata(
|
||||
num_reqs=num_seqs,
|
||||
num_heads=num_query_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_size,
|
||||
seq_lens=kv_lens_tensor,
|
||||
dtype=dtype,
|
||||
query_start_loc=cu_query_lens,
|
||||
causal=True,
|
||||
sliding_window_size=sliding_window if sliding_window is not None else -1,
|
||||
isa=isa,
|
||||
enable_kv_split=False,
|
||||
)
|
||||
|
||||
out_without_split = torch.empty_like(query)
|
||||
cpu_attention_with_kv_cache(
|
||||
query=query,
|
||||
key_cache=packed_key_cache,
|
||||
value_cache=packed_value_cache,
|
||||
output=out_without_split,
|
||||
query_start_loc=cu_query_lens,
|
||||
seq_lens=kv_lens_tensor,
|
||||
scale=scale,
|
||||
causal=True,
|
||||
alibi_slopes=alibi_slopes,
|
||||
sliding_window=window_size,
|
||||
block_table=block_tables,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
scheduler_metadata=metadata,
|
||||
s_aux=s_aux,
|
||||
)
|
||||
|
||||
metadata = cpu_attn_get_scheduler_metadata(
|
||||
num_reqs=num_seqs,
|
||||
num_heads=num_query_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_size,
|
||||
seq_lens=kv_lens_tensor,
|
||||
dtype=dtype,
|
||||
query_start_loc=cu_query_lens,
|
||||
causal=True,
|
||||
sliding_window_size=sliding_window if sliding_window is not None else -1,
|
||||
isa=isa,
|
||||
enable_kv_split=True,
|
||||
)
|
||||
|
||||
out_with_split = torch.empty_like(query)
|
||||
cpu_attention_with_kv_cache(
|
||||
query=query,
|
||||
key_cache=packed_key_cache,
|
||||
value_cache=packed_value_cache,
|
||||
output=out_with_split,
|
||||
query_start_loc=cu_query_lens,
|
||||
seq_lens=kv_lens_tensor,
|
||||
scale=scale,
|
||||
causal=True,
|
||||
alibi_slopes=alibi_slopes,
|
||||
sliding_window=window_size,
|
||||
block_table=block_tables,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
scheduler_metadata=metadata,
|
||||
s_aux=s_aux,
|
||||
)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=query_lens,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
sliding_window=sliding_window,
|
||||
soft_cap=soft_cap,
|
||||
alibi_slopes=alibi_slopes,
|
||||
s_aux=s_aux,
|
||||
)
|
||||
|
||||
atol, rtol = 1.5e-2, 1e-2
|
||||
(
|
||||
torch.testing.assert_close(out_with_split, ref_output, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(out_with_split - ref_output))}",
|
||||
)
|
||||
(
|
||||
torch.testing.assert_close(out_without_split, ref_output, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(out_without_split - ref_output))}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", [96, 128])
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("dtype", QTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("use_alibi", [False])
|
||||
@pytest.mark.parametrize("use_sink", [False])
|
||||
@pytest.mark.parametrize("isa", ["vec"])
|
||||
def test_varlen_with_paged_kv_normal_vec(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
sliding_window=sliding_window,
|
||||
dtype=dtype,
|
||||
block_size=block_size,
|
||||
soft_cap=soft_cap,
|
||||
num_blocks=num_blocks,
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", [96, 128])
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("use_alibi", [False])
|
||||
@pytest.mark.parametrize("use_sink", [False])
|
||||
@pytest.mark.parametrize("isa", ["amx"])
|
||||
@pytest.mark.skipif(
|
||||
not torch._C._cpu._is_amx_tile_supported(), reason="no AMX support."
|
||||
)
|
||||
def test_varlen_with_paged_kv_normal_amx(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
sliding_window=sliding_window,
|
||||
dtype=dtype,
|
||||
block_size=block_size,
|
||||
soft_cap=soft_cap,
|
||||
num_blocks=num_blocks,
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", [48])
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("use_alibi", [False])
|
||||
@pytest.mark.parametrize("use_sink", [False])
|
||||
@pytest.mark.parametrize("isa", ["vec16"])
|
||||
def test_varlen_with_paged_kv_normal_vec16(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
sliding_window=sliding_window,
|
||||
dtype=dtype,
|
||||
block_size=block_size,
|
||||
soft_cap=soft_cap,
|
||||
num_blocks=num_blocks,
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", [96])
|
||||
@pytest.mark.parametrize("block_size", [128])
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("soft_cap", [50])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("use_alibi", [False])
|
||||
@pytest.mark.parametrize("use_sink", [False])
|
||||
@pytest.mark.parametrize(
|
||||
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
||||
)
|
||||
def test_varlen_with_paged_kv_softcap(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
sliding_window=sliding_window,
|
||||
dtype=dtype,
|
||||
block_size=block_size,
|
||||
soft_cap=soft_cap,
|
||||
num_blocks=num_blocks,
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", [96])
|
||||
@pytest.mark.parametrize("block_size", [128])
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("use_alibi", [True])
|
||||
@pytest.mark.parametrize("use_sink", [False])
|
||||
@pytest.mark.parametrize(
|
||||
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
||||
)
|
||||
def test_varlen_with_paged_kv_alibi(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
sliding_window=sliding_window,
|
||||
dtype=dtype,
|
||||
block_size=block_size,
|
||||
soft_cap=soft_cap,
|
||||
num_blocks=num_blocks,
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", [96])
|
||||
@pytest.mark.parametrize("block_size", [128])
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("use_alibi", [False])
|
||||
@pytest.mark.parametrize("use_sink", [True])
|
||||
@pytest.mark.parametrize(
|
||||
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
||||
)
|
||||
def test_varlen_with_paged_kv_sink(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
sliding_window=sliding_window,
|
||||
dtype=dtype,
|
||||
block_size=block_size,
|
||||
soft_cap=soft_cap,
|
||||
num_blocks=num_blocks,
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
)
|
||||
@ -11,7 +11,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.attention.selector import _cached_get_attn_backend
|
||||
from vllm.platforms import current_platform
|
||||
@ -43,14 +43,14 @@ def test_mha_attn_platform(device: str):
|
||||
patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.TORCH_SDPA
|
||||
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
|
||||
elif device == "hip":
|
||||
with (
|
||||
patch("vllm.attention.layer.current_platform", RocmPlatform()),
|
||||
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.TORCH_SDPA
|
||||
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
|
||||
else:
|
||||
# Test CUDA with head_size=64 (divisible by 32)
|
||||
# - should use vLLM's FlashAttention
|
||||
@ -59,7 +59,7 @@ def test_mha_attn_platform(device: str):
|
||||
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.FLASH_ATTN
|
||||
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
# Test CUDA with head_size=72 (not divisible by 32)
|
||||
# - with upstream FA not available
|
||||
@ -73,7 +73,7 @@ def test_mha_attn_platform(device: str):
|
||||
),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 72, scale=1)
|
||||
assert attn.attn_backend == _Backend.XFORMERS
|
||||
assert attn.attn_backend == AttentionBackendEnum.XFORMERS
|
||||
|
||||
# Test CUDA with head_size=72 (not divisible by 32)
|
||||
# - with upstream FA available
|
||||
@ -96,7 +96,7 @@ def test_mha_attn_platform(device: str):
|
||||
),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 72, scale=1)
|
||||
assert attn.attn_backend == _Backend.FLASH_ATTN
|
||||
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
|
||||
def ref_attention(
|
||||
|
||||
@ -8,10 +8,8 @@ from collections.abc import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tests.kernels.utils import make_alibi_bias
|
||||
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
from vllm.platforms import current_platform
|
||||
@ -28,6 +26,74 @@ KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
|
||||
OPS = [chunked_prefill_paged_decode, context_attention_fwd]
|
||||
|
||||
|
||||
def create_causal_attention_mask_for_sdpa(
|
||||
query_lens: list[int],
|
||||
seq_lens: list[int],
|
||||
sliding_window: int = 0,
|
||||
device: torch.device = None,
|
||||
dtype: torch.dtype = None,
|
||||
) -> torch.Tensor:
|
||||
total_queries = sum(query_lens)
|
||||
total_keys = sum(seq_lens)
|
||||
|
||||
# Create a mask filled with -inf
|
||||
mask = torch.full(
|
||||
(total_queries, total_keys), float("-inf"), device=device, dtype=dtype
|
||||
)
|
||||
|
||||
query_start = 0
|
||||
key_start = 0
|
||||
|
||||
for query_len, seq_len in zip(query_lens, seq_lens):
|
||||
query_end = query_start + query_len
|
||||
key_end = key_start + seq_len
|
||||
q_indices = torch.arange(query_len, device=device)
|
||||
k_indices = torch.arange(seq_len, device=device)
|
||||
q_pos_in_seq = seq_len - query_len + q_indices
|
||||
|
||||
valid_mask = k_indices[None, :] <= q_pos_in_seq[:, None]
|
||||
|
||||
if sliding_window > 0:
|
||||
valid_mask &= k_indices[None, :] >= (
|
||||
q_pos_in_seq[:, None] - sliding_window + 1
|
||||
)
|
||||
|
||||
mask[query_start:query_end, key_start:key_end][valid_mask] = 0.0
|
||||
|
||||
query_start = query_end
|
||||
key_start = key_end
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def create_alibi_causal_mask(
|
||||
query_len: int,
|
||||
seq_len: int,
|
||||
alibi_slopes: torch.Tensor,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
query_pos = torch.arange(
|
||||
seq_len - query_len, seq_len, device=device, dtype=torch.float32
|
||||
)
|
||||
key_pos = torch.arange(seq_len, device=device, dtype=torch.float32)
|
||||
|
||||
rel_pos = key_pos[None, :] - query_pos[:, None]
|
||||
|
||||
# Apply ALiBi slopes: [num_heads, query_len, seq_len]
|
||||
alibi_bias = alibi_slopes[:, None, None] * rel_pos[None, :, :]
|
||||
alibi_bias = alibi_bias.to(dtype)
|
||||
|
||||
# Apply causal mask: prevent attending to future positions
|
||||
# causal_mask[i, j] = True if key_pos[j] <= query_pos[i]
|
||||
causal_mask = key_pos[None, :] <= query_pos[:, None]
|
||||
alibi_bias = alibi_bias.masked_fill(~causal_mask[None, :, :], float("-inf"))
|
||||
|
||||
# Add batch dimension: [1, num_heads, query_len, seq_len]
|
||||
# SDPA expects batch dimension even for single sequences
|
||||
return alibi_bias.unsqueeze(0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@ -52,6 +118,13 @@ def test_contexted_kv_attention(
|
||||
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
|
||||
)
|
||||
|
||||
if (
|
||||
current_platform.is_rocm()
|
||||
and op is chunked_prefill_paged_decode
|
||||
and kv_cache_dtype == "fp8_e5m2"
|
||||
):
|
||||
pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
torch.set_default_device(device)
|
||||
|
||||
@ -96,16 +169,16 @@ def test_contexted_kv_attention(
|
||||
)
|
||||
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
values = torch.arange(0, cache_size, dtype=torch.long)
|
||||
values = torch.arange(0, cache_size, dtype=torch.int32)
|
||||
values = values[torch.randperm(cache_size)]
|
||||
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
|
||||
max_input_len = MAX_SEQ_LEN
|
||||
# copy kv to cache
|
||||
b_seq_start_loc = torch.cumsum(
|
||||
torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
|
||||
torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
|
||||
)
|
||||
for i in range(BS):
|
||||
for j in range(query_lens[i]):
|
||||
@ -189,56 +262,57 @@ def test_contexted_kv_attention(
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
|
||||
attn_op = xops.fmha.cutlass.FwOp()
|
||||
|
||||
if num_kv_heads != num_heads:
|
||||
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
|
||||
# project the key and value tensors to the desired number of
|
||||
# heads.
|
||||
#
|
||||
# see also: vllm/model_executor/layers/attention.py
|
||||
query = query.view(
|
||||
query.shape[0], num_kv_heads, num_queries_per_kv, query.shape[-1]
|
||||
)
|
||||
key = key[:, :, None, :].expand(
|
||||
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
|
||||
)
|
||||
value = value[:, :, None, :].expand(
|
||||
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
|
||||
)
|
||||
query = query.unsqueeze(0)
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
|
||||
attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
|
||||
query_lens, seq_lens
|
||||
# Reshape for SDPA: (seq_len, num_heads, head_size) ->
|
||||
# (1, num_heads, seq_len, head_size)
|
||||
query_sdpa = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
|
||||
query_sdpa = query_sdpa.permute(1, 2, 0, 3).reshape(
|
||||
1, num_heads, num_tokens, head_size
|
||||
)
|
||||
if sliding_window > 0:
|
||||
attn_bias = attn_bias.make_local_attention_from_bottomright(sliding_window)
|
||||
output_ref = xops.memory_efficient_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_bias=attn_bias,
|
||||
p=0.0,
|
||||
|
||||
# Expand key and value for GQA/MQA to match query heads
|
||||
key_sdpa = key[:, :, None, :].expand(
|
||||
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
|
||||
)
|
||||
key_sdpa = key_sdpa.permute(1, 2, 0, 3).reshape(
|
||||
1, num_heads, sum(seq_lens), head_size
|
||||
)
|
||||
|
||||
value_sdpa = value[:, :, None, :].expand(
|
||||
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
|
||||
)
|
||||
value_sdpa = value_sdpa.permute(1, 2, 0, 3).reshape(
|
||||
1, num_heads, sum(seq_lens), head_size
|
||||
)
|
||||
|
||||
attn_mask = create_causal_attention_mask_for_sdpa(
|
||||
query_lens, seq_lens, sliding_window, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
output_ref = F.scaled_dot_product_attention(
|
||||
query_sdpa,
|
||||
key_sdpa,
|
||||
value_sdpa,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=0.0,
|
||||
scale=scale,
|
||||
op=attn_op,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
output_ref = xops.memory_efficient_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_bias=attn_bias,
|
||||
p=0.0,
|
||||
output_ref = F.scaled_dot_product_attention(
|
||||
query_sdpa,
|
||||
key_sdpa,
|
||||
value_sdpa,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=0.0,
|
||||
scale=scale,
|
||||
op=attn_op,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
|
||||
output_ref = output_ref.reshape(output.shape)
|
||||
print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
|
||||
|
||||
# Reshape output back to (num_tokens, num_heads, head_size)
|
||||
output_ref = output_ref.view(num_heads, num_tokens, head_size)
|
||||
output_ref = output_ref.permute(1, 0, 2).contiguous()
|
||||
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
|
||||
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
|
||||
|
||||
@ -265,6 +339,13 @@ def test_contexted_kv_attention_alibi(
|
||||
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
|
||||
)
|
||||
|
||||
if (
|
||||
current_platform.is_rocm()
|
||||
and op is chunked_prefill_paged_decode
|
||||
and kv_cache_dtype == "fp8_e5m2"
|
||||
):
|
||||
pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
torch.set_default_device(device)
|
||||
|
||||
@ -331,16 +412,16 @@ def test_contexted_kv_attention_alibi(
|
||||
)
|
||||
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
values = torch.arange(0, cache_size, dtype=torch.long)
|
||||
values = torch.arange(0, cache_size, dtype=torch.int32)
|
||||
values = values[torch.randperm(cache_size)]
|
||||
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
|
||||
max_input_len = MAX_SEQ_LEN
|
||||
# copy kv to cache
|
||||
b_seq_start_loc = torch.cumsum(
|
||||
torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
|
||||
torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
|
||||
)
|
||||
for i in range(BS):
|
||||
for j in range(query_lens[i]):
|
||||
@ -423,78 +504,75 @@ def test_contexted_kv_attention_alibi(
|
||||
print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
|
||||
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
|
||||
# we have to pad query tensor before MQA/GQA expanding.
|
||||
if query.shape[0] != key.shape[0]:
|
||||
query_pad = torch.empty(sum(seq_lens), num_heads, head_size, dtype=dtype)
|
||||
query_pad.uniform_(-1e-3, 1e-3)
|
||||
seq_start = 0
|
||||
query_start = 0
|
||||
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
|
||||
seq_end = seq_start + seq_len
|
||||
query_end = query_start + query_len
|
||||
query_pad[seq_start:seq_end, ...] = torch.cat(
|
||||
[
|
||||
torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype),
|
||||
query[query_start:query_end, ...],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
seq_start += seq_len
|
||||
query_start += query_len
|
||||
query = query_pad
|
||||
# Prepare query, key, value for SDPA
|
||||
# Expand key and value for GQA/MQA to match query heads
|
||||
key_expanded = key[:, :, None, :].expand(
|
||||
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
|
||||
)
|
||||
value_expanded = value[:, :, None, :].expand(
|
||||
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
|
||||
)
|
||||
|
||||
if num_kv_heads != num_heads:
|
||||
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
|
||||
# project the key and value tensors to the desired number of
|
||||
# heads.
|
||||
#
|
||||
# see also: vllm/model_executor/layers/attention.py
|
||||
key = key[:, :, None, :].expand(
|
||||
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
|
||||
)
|
||||
value = value[:, :, None, :].expand(
|
||||
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
|
||||
)
|
||||
# [seq, num_kv_heads, num_queries_per_kv, dk]=>
|
||||
# [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
|
||||
# codebase. We save some time reshaping alibi matrix at runtime.
|
||||
key = key.reshape(key.shape[0], -1, key.shape[-1])
|
||||
value = value.reshape(value.shape[0], -1, value.shape[-1])
|
||||
query = query.unsqueeze(0)
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
|
||||
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
|
||||
output_ref = torch.empty_like(output)
|
||||
seq_start = 0
|
||||
query_start = 0
|
||||
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
# Attention with alibi slopes.
|
||||
# FIXME(DefTruth): Because xformers does not support dynamic sequence
|
||||
# lengths with custom attention bias, we process each prompt one by
|
||||
# one. This is inefficient, especially when we have many short prompts.
|
||||
# modified from: vllm/v1/attention/backends/xformers.py#L343
|
||||
|
||||
query_start = 0
|
||||
key_start = 0
|
||||
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
|
||||
seq_end = seq_start + seq_len
|
||||
query_end = query_start + query_len
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query[:, seq_start:seq_end],
|
||||
key[:, seq_start:seq_end],
|
||||
value[:, seq_start:seq_end],
|
||||
attn_bias=attn_bias[i],
|
||||
p=0.0,
|
||||
key_end = key_start + seq_len
|
||||
|
||||
# Get query, key, value for this sequence
|
||||
q = query[query_start:query_end] # [query_len, num_heads, head_size]
|
||||
k = key_expanded[
|
||||
key_start:key_end
|
||||
] # [seq_len, num_kv_heads, num_queries_per_kv, head_size]
|
||||
v = value_expanded[
|
||||
key_start:key_end
|
||||
] # [seq_len, num_kv_heads, num_queries_per_kv, head_size]
|
||||
|
||||
# Reshape for SDPA: (batch=1, num_heads, seq_len, head_size)
|
||||
q_sdpa = q.view(query_len, num_kv_heads, num_queries_per_kv, head_size)
|
||||
q_sdpa = (
|
||||
q_sdpa.permute(1, 2, 0, 3)
|
||||
.reshape(1, num_heads, query_len, head_size)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
k_sdpa = (
|
||||
k.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
|
||||
)
|
||||
v_sdpa = (
|
||||
v.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
|
||||
)
|
||||
|
||||
# Create ALiBi causal mask for this sequence using utility function
|
||||
alibi_mask = create_alibi_causal_mask(
|
||||
query_len, seq_len, alibi_slopes, device, dtype
|
||||
)
|
||||
|
||||
# Compute attention
|
||||
out = F.scaled_dot_product_attention(
|
||||
q_sdpa,
|
||||
k_sdpa,
|
||||
v_sdpa,
|
||||
attn_mask=alibi_mask,
|
||||
dropout_p=0.0,
|
||||
scale=scale,
|
||||
)
|
||||
out = out.view_as(query[:, seq_start:seq_end]).view(
|
||||
seq_len, num_heads, head_size
|
||||
)
|
||||
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len :, ...])
|
||||
seq_start += seq_len
|
||||
query_start += query_len
|
||||
|
||||
# Reshape output back to [query_len, num_heads, head_size]
|
||||
out = out.view(num_heads, query_len, head_size).permute(1, 0, 2)
|
||||
output_ref[query_start:query_end].copy_(out)
|
||||
|
||||
query_start = query_end
|
||||
key_start = key_end
|
||||
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
|
||||
print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
|
||||
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
|
||||
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
|
||||
|
||||
|
||||
141
tests/kernels/core/test_fused_qk_norm_rope.py
Normal file
141
tests/kernels/core/test_fused_qk_norm_rope.py
Normal file
@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float16]
|
||||
IS_NEOX = [True, False]
|
||||
EPS_VALUES = [1e-5, 1e-6]
|
||||
SEEDS = [13]
|
||||
CUDA_DEVICES = ["cuda:0"]
|
||||
|
||||
|
||||
def _apply_qk_norm_rope(
|
||||
qkv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
q_norm: RMSNorm,
|
||||
k_norm: RMSNorm,
|
||||
rope: RotaryEmbedding,
|
||||
num_heads_q: int,
|
||||
num_heads_kv: int,
|
||||
head_dim: int,
|
||||
) -> torch.Tensor:
|
||||
q_size = num_heads_q * head_dim
|
||||
kv_size = num_heads_kv * head_dim
|
||||
|
||||
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
|
||||
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
|
||||
q_by_head = q_norm.forward_native(q_by_head)
|
||||
q = q_by_head.view(q.shape)
|
||||
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
|
||||
k_by_head = k_norm.forward_native(k_by_head)
|
||||
k = k_by_head.view(k.shape)
|
||||
|
||||
q, k = rope.forward_native(positions, q, k)
|
||||
return torch.cat([q, k, v], dim=-1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="fused_qk_norm_rope custom op requires cuda and rocm platform",
|
||||
)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("is_neox", IS_NEOX)
|
||||
@pytest.mark.parametrize("eps", EPS_VALUES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_fused_qk_norm_rope_matches_reference(
|
||||
device: str,
|
||||
dtype: torch.dtype,
|
||||
is_neox: bool,
|
||||
eps: float,
|
||||
seed: int,
|
||||
):
|
||||
torch.set_default_device(device)
|
||||
current_platform.seed_everything(seed)
|
||||
num_heads, num_kv_heads, head_dim = 16, 4, 128
|
||||
num_tokens = 4
|
||||
|
||||
total_dim = (num_heads + 2 * num_kv_heads) * head_dim
|
||||
qkv_base = torch.randn(num_tokens, total_dim, dtype=dtype, device=device)
|
||||
qkv_fused = qkv_base.clone()
|
||||
positions = torch.arange(num_tokens, dtype=torch.long, device=device)
|
||||
|
||||
q_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype)
|
||||
k_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype)
|
||||
q_norm.weight.data.normal_(mean=1.0, std=0.1)
|
||||
k_norm.weight.data.normal_(mean=1.0, std=0.1)
|
||||
q_weight = q_norm.weight.data
|
||||
k_weight = k_norm.weight.data
|
||||
|
||||
rope = RotaryEmbedding(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
max_position_embeddings=4096,
|
||||
base=10000.0,
|
||||
is_neox_style=is_neox,
|
||||
dtype=dtype,
|
||||
).to(device)
|
||||
|
||||
ref_result = _apply_qk_norm_rope(
|
||||
qkv=qkv_base,
|
||||
positions=positions,
|
||||
q_norm=q_norm,
|
||||
k_norm=k_norm,
|
||||
rope=rope,
|
||||
num_heads_q=num_heads,
|
||||
num_heads_kv=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.fused_qk_norm_rope,
|
||||
(
|
||||
qkv_fused.clone(),
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
eps,
|
||||
q_weight,
|
||||
k_weight,
|
||||
rope.cos_sin_cache,
|
||||
is_neox,
|
||||
positions.view(-1),
|
||||
),
|
||||
)
|
||||
|
||||
torch.ops._C.fused_qk_norm_rope(
|
||||
qkv_fused,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
eps,
|
||||
q_weight,
|
||||
k_weight,
|
||||
rope.cos_sin_cache,
|
||||
is_neox,
|
||||
positions.view(-1),
|
||||
)
|
||||
|
||||
if dtype == torch.float16:
|
||||
ATOL, RTOL = (2e-3, 2e-3)
|
||||
else:
|
||||
ATOL, RTOL = (1e-2, 1e-2)
|
||||
|
||||
torch.testing.assert_close(
|
||||
qkv_fused,
|
||||
ref_result,
|
||||
atol=ATOL,
|
||||
rtol=RTOL,
|
||||
)
|
||||
@ -6,6 +6,10 @@ import torch
|
||||
|
||||
# Fused experts and PrepareFinalize imports
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
@ -21,7 +25,6 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts,
|
||||
NaiveBatchedExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, TritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
@ -399,9 +402,7 @@ def make_prepare_finalize(
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.FusedMoEPrepareAndFinalize:
|
||||
if backend != "naive" and backend is not None:
|
||||
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(
|
||||
moe, quant_config
|
||||
)
|
||||
prepare_finalize = maybe_make_prepare_finalize(moe, quant_config)
|
||||
assert prepare_finalize is not None
|
||||
return prepare_finalize
|
||||
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
|
||||
|
||||
@ -8,6 +8,7 @@ import torch
|
||||
import vllm._custom_ops as ops
|
||||
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.platform_utils import get_cu_count
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float16]
|
||||
# Specific (N, K, M) combinations for targeted testing
|
||||
@ -85,7 +86,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
|
||||
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
cu_count = current_platform.get_cu_count()
|
||||
cu_count = get_cu_count()
|
||||
|
||||
A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
|
||||
B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5
|
||||
@ -102,7 +103,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
|
||||
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
cu_count = current_platform.get_cu_count()
|
||||
cu_count = get_cu_count()
|
||||
|
||||
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
|
||||
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
|
||||
@ -121,7 +122,7 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
|
||||
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
cu_count = current_platform.get_cu_count()
|
||||
cu_count = get_cu_count()
|
||||
|
||||
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
|
||||
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
|
||||
@ -153,7 +154,14 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
|
||||
ref_out = torch._scaled_mm(
|
||||
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b
|
||||
)
|
||||
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, current_platform.get_cu_count())
|
||||
out = ops.wvSplitKQ(
|
||||
B,
|
||||
A,
|
||||
dtype,
|
||||
scale_a,
|
||||
scale_b,
|
||||
get_cu_count(),
|
||||
)
|
||||
|
||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||
|
||||
@ -180,7 +188,13 @@ def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
|
||||
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
|
||||
)
|
||||
out = ops.wvSplitKQ(
|
||||
B, A, dtype, scale_a, scale_b, current_platform.get_cu_count(), BIAS
|
||||
B,
|
||||
A,
|
||||
dtype,
|
||||
scale_a,
|
||||
scale_b,
|
||||
get_cu_count(),
|
||||
BIAS,
|
||||
)
|
||||
|
||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Integration tests for FlexAttention backend vs default backend"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@ -1,516 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the triton_flash_attention kernel
|
||||
|
||||
Run `pytest tests/kernels/test_triton_flash_attention.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.ops.triton_flash_attention import (
|
||||
SUPPORTED_LAYOUTS,
|
||||
MetaData,
|
||||
compute_alibi_tensor,
|
||||
scale_fp8,
|
||||
triton_attention_rocm,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class ReferenceAttention:
|
||||
def __init__(
|
||||
self, Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, input_metadata
|
||||
):
|
||||
self.Z = Z
|
||||
self.HQ = HQ
|
||||
self.HK = HK
|
||||
self.N_CTX_Q = N_CTX_Q
|
||||
self.N_CTX_K = N_CTX_K
|
||||
self.D_HEAD = D_HEAD
|
||||
self.use_alibi = use_alibi
|
||||
self.dtype = dtype
|
||||
self.input_metadata = input_metadata
|
||||
|
||||
def fwd(self, q, k, v):
|
||||
scores = (
|
||||
torch.einsum("bhqd,bhkd->bhqk", q, k).float() * self.input_metadata.sm_scale
|
||||
)
|
||||
if self.input_metadata.causal:
|
||||
mask = torch.tril(
|
||||
torch.ones(self.N_CTX_Q, self.N_CTX_K, device="cuda"),
|
||||
diagonal=self.N_CTX_K - self.N_CTX_Q,
|
||||
)
|
||||
scores[:, :, mask == 0] = float("-inf")
|
||||
|
||||
if self.input_metadata.bias is not None:
|
||||
scores += self.input_metadata.bias
|
||||
|
||||
if self.use_alibi:
|
||||
scores += compute_alibi_tensor(
|
||||
self.input_metadata.alibi_slopes, self.N_CTX_Q, self.N_CTX_K
|
||||
)
|
||||
|
||||
p = torch.softmax(scores, dim=-1)
|
||||
if self.input_metadata.causal:
|
||||
# If N_CTX_Q > N_CTX_K, there's at least one row of all -infs going
|
||||
# into softmax. This creates a row of NaNs as -inf - -inf == NaN.
|
||||
# So we fix this by converting the NaNs to 0s, which is what they
|
||||
# should be out of the softmax.
|
||||
nan_mask = torch.isnan(p)
|
||||
p[nan_mask == 1] = 0
|
||||
ref_out = torch.einsum("bhqk,bhkd->bhqd", p.to(self.dtype), v)
|
||||
# compare
|
||||
if self.input_metadata.layout == "bshd":
|
||||
ref_out = ref_out.transpose(1, 2).clone()
|
||||
return ref_out
|
||||
|
||||
def fwd_fp8(self, q_quantized, k_quantized, v_quantized):
|
||||
q = (q_quantized.to(torch.float16) * self.input_metadata.q_descale).to(
|
||||
self.dtype
|
||||
)
|
||||
k = (k_quantized.to(torch.float16) * self.input_metadata.k_descale).to(
|
||||
self.dtype
|
||||
)
|
||||
v = (v_quantized.to(torch.float16) * self.input_metadata.v_descale).to(
|
||||
self.dtype
|
||||
)
|
||||
result = self.fwd(q, k, v)
|
||||
if self.input_metadata.o_scale is not None:
|
||||
result, _ = scale_fp8(result, self.input_metadata.o_scale)
|
||||
return result
|
||||
|
||||
def fwd_fp8_kv(self, q, k_quantized, v_quantized):
|
||||
k_descale, v_descale = (
|
||||
self.input_metadata.k_descale,
|
||||
self.input_metadata.v_descale,
|
||||
)
|
||||
k_dequantized = (
|
||||
k_quantized.to(torch.float32) * k_descale.to(torch.float32)
|
||||
).to(self.dtype)
|
||||
v_dequantized = (
|
||||
v_quantized.to(torch.float32) * v_descale.to(torch.float32)
|
||||
).to(self.dtype)
|
||||
return self.fwd(q, k_dequantized, v_dequantized)
|
||||
|
||||
def varlen_fwd(self, q, k, v, is_mqa=False):
|
||||
ref_out = torch.empty_like(q)
|
||||
if is_mqa:
|
||||
# Make KV look like HQ/HK "groups" of HK. Later, we will reshape so
|
||||
# the size aligns with Q.
|
||||
k_ref = k.view(k.shape[0], k.shape[1], 1, k.shape[2]).expand(
|
||||
-1, -1, self.HQ // self.HK, -1
|
||||
)
|
||||
v_ref = v.view(v.shape[0], v.shape[1], 1, v.shape[2]).expand(
|
||||
-1, -1, self.HQ // self.HK, -1
|
||||
)
|
||||
else:
|
||||
k_ref = k
|
||||
v_ref = v
|
||||
|
||||
for i in range(0, self.input_metadata.num_contexts):
|
||||
start_q, start_k = (
|
||||
self.input_metadata.cu_seqlens_q[i],
|
||||
self.input_metadata.cu_seqlens_k[i],
|
||||
)
|
||||
end_q, end_k = (
|
||||
self.input_metadata.cu_seqlens_q[i + 1],
|
||||
self.input_metadata.cu_seqlens_k[i + 1],
|
||||
)
|
||||
k_curr = k_ref[start_k:end_k]
|
||||
v_curr = v_ref[start_k:end_k]
|
||||
if is_mqa:
|
||||
k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3])
|
||||
v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3])
|
||||
scores = torch.einsum("qhd,khd->qhk", q[start_q:end_q], k_curr).float()
|
||||
p = torch.softmax(scores * self.input_metadata.sm_scale, dim=-1).half()
|
||||
ref_out[start_q:end_q] = torch.einsum("qhk,khd->qhd", p, v_curr)
|
||||
return ref_out
|
||||
|
||||
|
||||
def quantize_input(q, k, v, fp8_kv=False, use_o_scale=False):
|
||||
q_descale = None
|
||||
if not fp8_kv:
|
||||
q, q_descale = scale_fp8(q)
|
||||
k, k_descale = scale_fp8(k)
|
||||
v, v_descale = scale_fp8(v)
|
||||
|
||||
# In real world use case, the p scale would be a parameter trained by the
|
||||
# model.
|
||||
p_scale = None
|
||||
|
||||
o_scale = torch.rand(1, device="cuda", requires_grad=False) if use_o_scale else None
|
||||
|
||||
return q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale
|
||||
|
||||
|
||||
def input_helper(
|
||||
Z,
|
||||
HQ,
|
||||
HK,
|
||||
N_CTX_Q,
|
||||
N_CTX_K,
|
||||
D_HEAD,
|
||||
dtype,
|
||||
layout=None,
|
||||
use_alibi=None,
|
||||
causal=None,
|
||||
is_fp8=False,
|
||||
fp8_kv=False,
|
||||
use_o_scale=False,
|
||||
use_bias=False,
|
||||
):
|
||||
assert layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
# Initialize q, k, v
|
||||
if layout == "bhsd":
|
||||
q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD)
|
||||
k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD)
|
||||
elif layout == "bshd":
|
||||
q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD)
|
||||
k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD)
|
||||
|
||||
if use_alibi:
|
||||
# for n heads the set of slopes is the geometric sequence that starts
|
||||
# 2^(-8/n)
|
||||
alibi_slopes = torch.tensor(
|
||||
[2 ** (-8 / HQ * i) for i in range(1, HQ + 1)],
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
).repeat(Z, 1)
|
||||
else:
|
||||
alibi_slopes = None
|
||||
|
||||
if use_bias:
|
||||
bias = torch.randn(
|
||||
(1, HQ, N_CTX_Q, N_CTX_K), dtype=dtype, device="cuda", requires_grad=False
|
||||
)
|
||||
else:
|
||||
bias = None
|
||||
|
||||
q = torch.randn(q_tensor_shape, dtype=dtype, device="cuda", requires_grad=False)
|
||||
k = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=False)
|
||||
v = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=False)
|
||||
|
||||
if is_fp8:
|
||||
(q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale) = quantize_input(
|
||||
q, k, v, use_o_scale=use_o_scale, fp8_kv=fp8_kv
|
||||
)
|
||||
else:
|
||||
q_descale = k_descale = v_descale = p_scale = o_scale = None
|
||||
|
||||
input_metadata = MetaData(
|
||||
sm_scale=D_HEAD**-0.5,
|
||||
max_seqlens_q=N_CTX_Q,
|
||||
max_seqlens_k=N_CTX_K,
|
||||
layout=layout,
|
||||
alibi_slopes=alibi_slopes,
|
||||
alibi_batch=Z,
|
||||
alibi_nheads=HQ,
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
p_scale=p_scale,
|
||||
o_scale=o_scale,
|
||||
bias=bias,
|
||||
seqlen_q=N_CTX_Q,
|
||||
seqlen_k=N_CTX_K,
|
||||
)
|
||||
return q, k, v, input_metadata
|
||||
|
||||
|
||||
def varlen_input_helper(
|
||||
Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlens=False
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
# Random sequence lengths. Using N_CTX as kind of max of sum of individual
|
||||
# seqs
|
||||
if not equal_seqlens:
|
||||
max_seqlens_q = N_CTX_Q // Z
|
||||
max_seqlens_k = N_CTX_K // Z
|
||||
seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32)
|
||||
seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32)
|
||||
else:
|
||||
seqlens_q = torch.full((Z,), N_CTX_Q // Z)
|
||||
seqlens_k = torch.full((Z,), N_CTX_K // Z)
|
||||
|
||||
# Calculate cumulative sequence lengths
|
||||
cu_seqlens_q = torch.cat(
|
||||
[
|
||||
torch.tensor([0], dtype=torch.int32),
|
||||
seqlens_q.cumsum(dim=0, dtype=torch.int32),
|
||||
]
|
||||
)
|
||||
cu_seqlens_k = torch.cat(
|
||||
[
|
||||
torch.tensor([0], dtype=torch.int32),
|
||||
seqlens_k.cumsum(dim=0, dtype=torch.int32),
|
||||
]
|
||||
)
|
||||
cu_seqlens_q = cu_seqlens_q.to(device="cuda")
|
||||
cu_seqlens_k = cu_seqlens_k.to(device="cuda")
|
||||
|
||||
# Initialize q, k, v with variable lengths
|
||||
total_q = cu_seqlens_q[-1].item()
|
||||
total_k = cu_seqlens_k[-1].item()
|
||||
q = (
|
||||
torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device="cuda")
|
||||
.normal_(mean=0.0, std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
k = (
|
||||
torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda")
|
||||
.normal_(mean=0.0, std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
v = (
|
||||
torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda")
|
||||
.normal_(mean=0.0, std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
sm_scale = D_HEAD**-0.5
|
||||
input_metadata = MetaData(sm_scale=sm_scale)
|
||||
input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
|
||||
return q, k, v, input_metadata
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD",
|
||||
[
|
||||
(1, 48, 12, 1, 1, 64),
|
||||
(4, 4, 4, 128, 128, 65),
|
||||
(16, 48, 48, 1, 1, 128),
|
||||
(64, 48, 24, 3, 3, 128),
|
||||
(4, 4, 4, 113, 123, 1),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("causal", [True, False])
|
||||
@pytest.mark.parametrize("use_alibi", [True, False])
|
||||
@pytest.mark.parametrize("layout", ["bshd"])
|
||||
def test_op_fwd(
|
||||
Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
q, k, v, input_metadata = input_helper(
|
||||
Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, use_alibi, causal
|
||||
)
|
||||
|
||||
o = torch.empty_like(q)
|
||||
|
||||
# triton implementation
|
||||
tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata)
|
||||
|
||||
# Transpose here if layout is bshd so we have same reference code for all
|
||||
# layouts
|
||||
if layout == "bshd":
|
||||
q = q.transpose(1, 2).clone()
|
||||
k = k.transpose(1, 2).clone()
|
||||
v = v.transpose(1, 2).clone()
|
||||
# Replicate K and V if using MQA/GQA
|
||||
if HQ != HK:
|
||||
k = (
|
||||
k.view(k.shape[0], k.shape[1], -1, k.shape[2], k.shape[3])
|
||||
.expand(-1, -1, HQ // HK, -1, -1)
|
||||
.reshape(k.shape[0], -1, k.shape[2], k.shape[3])
|
||||
)
|
||||
v = (
|
||||
v.view(v.shape[0], v.shape[1], -1, v.shape[2], v.shape[3])
|
||||
.expand(-1, -1, HQ // HK, -1, -1)
|
||||
.reshape(v.shape[0], -1, v.shape[2], v.shape[3])
|
||||
)
|
||||
|
||||
ref_impl = ReferenceAttention(
|
||||
Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, input_metadata
|
||||
)
|
||||
ref_out = ref_impl.fwd(q, k, v)
|
||||
|
||||
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"Z, H, N_CTX_Q, N_CTX_K, D_HEAD",
|
||||
[
|
||||
(4, 48, 1, 1, 64),
|
||||
(4, 48, 1, 1, 128),
|
||||
(4, 48, 3, 3, 128),
|
||||
(4, 4, 128, 128, 65),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("causal", [True, False])
|
||||
@pytest.mark.parametrize("layout", ["bhsd"])
|
||||
@pytest.mark.parametrize("use_o_scale", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.get_device_capability() < (9, 0),
|
||||
reason="Triton FP8 requires CUDA 9.0 or higher",
|
||||
)
|
||||
def test_op_fwd_fp8(
|
||||
Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, layout, use_o_scale, dtype=torch.float32
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
# Disable grad to save memory it won't run into OOM on CI machine.
|
||||
# q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD,
|
||||
# dtype, layout)
|
||||
|
||||
q_quantized, k_quantized, v_quantized, input_metadata = input_helper(
|
||||
Z,
|
||||
H,
|
||||
H,
|
||||
N_CTX_Q,
|
||||
N_CTX_K,
|
||||
D_HEAD,
|
||||
dtype,
|
||||
causal=causal,
|
||||
layout=layout,
|
||||
is_fp8=True,
|
||||
use_o_scale=use_o_scale,
|
||||
)
|
||||
|
||||
o = torch.empty_like(q_quantized) if use_o_scale else None
|
||||
|
||||
tri_out, _ = triton_attention_rocm(
|
||||
q_quantized, k_quantized, v_quantized, o, input_metadata
|
||||
)
|
||||
|
||||
ref_impl = ReferenceAttention(
|
||||
Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata
|
||||
)
|
||||
ref_out = ref_impl.fwd_fp8(q_quantized, k_quantized, v_quantized)
|
||||
|
||||
# compare
|
||||
torch.testing.assert_close(
|
||||
ref_out.to(torch.float32), tri_out.to(torch.float32), atol=7e-2, rtol=2e-1
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"Z, H, N_CTX_Q, N_CTX_K, D_HEAD",
|
||||
[
|
||||
(4, 48, 1, 1, 64),
|
||||
(4, 48, 1, 1, 128),
|
||||
(4, 48, 3, 3, 128),
|
||||
(4, 4, 128, 128, 65),
|
||||
(4, 4, 113, 123, 1),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("causal", [True, False])
|
||||
@pytest.mark.parametrize("layout", ["bhsd"])
|
||||
def test_op_fwd_fp8_kv(
|
||||
Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, layout, dtype=torch.float32
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
q, k_quantized, v_quantized, input_metadata = input_helper(
|
||||
Z,
|
||||
H,
|
||||
H,
|
||||
N_CTX_Q,
|
||||
N_CTX_K,
|
||||
D_HEAD,
|
||||
dtype,
|
||||
causal=causal,
|
||||
layout=layout,
|
||||
is_fp8=True,
|
||||
fp8_kv=True,
|
||||
)
|
||||
|
||||
o = torch.empty_like(q)
|
||||
|
||||
tri_out, _ = triton_attention_rocm(q, k_quantized, v_quantized, o, input_metadata)
|
||||
|
||||
ref_impl = ReferenceAttention(
|
||||
Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata
|
||||
)
|
||||
ref_out = ref_impl.fwd_fp8_kv(q, k_quantized, v_quantized)
|
||||
|
||||
torch.testing.assert_close(ref_out, tri_out, atol=3e-2, rtol=8e-1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"Z, H, N_CTX_Q, N_CTX_K, D_HEAD",
|
||||
[
|
||||
(4, 48, 1, 1, 64),
|
||||
(4, 48, 1, 1, 128),
|
||||
(4, 48, 3, 3, 128),
|
||||
(4, 4, 128, 128, 65),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("causal", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype):
|
||||
current_platform.seed_everything(0)
|
||||
q, k, v, input_metadata = input_helper(
|
||||
Z,
|
||||
H,
|
||||
H,
|
||||
N_CTX_Q,
|
||||
N_CTX_K,
|
||||
D_HEAD,
|
||||
dtype,
|
||||
layout="bhsd",
|
||||
causal=causal,
|
||||
use_bias=use_bias,
|
||||
)
|
||||
o = torch.empty_like(q)
|
||||
|
||||
# triton implementation
|
||||
tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata)
|
||||
|
||||
ref_impl = ReferenceAttention(
|
||||
Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata
|
||||
)
|
||||
ref_out = ref_impl.fwd(q, k, v)
|
||||
|
||||
# compare
|
||||
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
|
||||
|
||||
|
||||
# NOTE: Uses thd layout, so also tests thd.
|
||||
@pytest.mark.parametrize(
|
||||
"Z, H, N_CTX, D_HEAD",
|
||||
[(1, 48, 256, 64), (4, 48, 512, 64), (16, 48, 512, 64), (64, 48, 128, 128)],
|
||||
)
|
||||
@pytest.mark.parametrize("causal", [True, False])
|
||||
def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
|
||||
q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype)
|
||||
|
||||
tri_out = torch.empty_like(q)
|
||||
triton_attention_rocm(q, k, v, tri_out, input_metadata)
|
||||
|
||||
ref_impl = ReferenceAttention(
|
||||
Z, H, H, N_CTX, N_CTX, D_HEAD, False, dtype, input_metadata
|
||||
)
|
||||
ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=False)
|
||||
|
||||
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
|
||||
|
||||
|
||||
# NOTE: Uses thd layout, so also tests thd.
|
||||
@pytest.mark.parametrize(
|
||||
"Z, HQ, HK, N_CTX, D_HEAD",
|
||||
[
|
||||
(2, 48, 24, 128, 64),
|
||||
(4, 48, 12, 256, 64),
|
||||
(4, 48, 4, 512, 64),
|
||||
(4, 64, 16, 128, 128),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("causal", [False])
|
||||
def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16):
|
||||
q, k, v, input_metadata = varlen_input_helper(
|
||||
Z, HQ, HK, N_CTX, N_CTX, D_HEAD, dtype
|
||||
)
|
||||
|
||||
tri_out = torch.empty_like(q)
|
||||
triton_attention_rocm(q, k, v, tri_out, input_metadata)
|
||||
|
||||
ref_impl = ReferenceAttention(
|
||||
Z, HQ, HK, N_CTX, N_CTX, D_HEAD, False, dtype, input_metadata
|
||||
)
|
||||
ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=True)
|
||||
|
||||
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
|
||||
@ -38,7 +38,11 @@ AITER_MODEL_LIST = [
|
||||
[
|
||||
pytest.param(
|
||||
"bigscience/bloom-560m", # bloom - testing alibi slopes
|
||||
marks=[pytest.mark.core_model, pytest.mark.slow_test],
|
||||
marks=[
|
||||
pytest.mark.core_model,
|
||||
pytest.mark.slow_test,
|
||||
pytest.mark.cpu_model,
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
"openai-community/gpt2", # gpt2
|
||||
@ -55,6 +59,10 @@ AITER_MODEL_LIST = [
|
||||
pytest.mark.slow_test,
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
"google/gemma-2-2b-it", # test hybrid attention
|
||||
marks=[pytest.mark.cpu_model],
|
||||
),
|
||||
pytest.param(
|
||||
"zai-org/chatglm3-6b", # chatglm (text-only)
|
||||
),
|
||||
@ -64,7 +72,6 @@ AITER_MODEL_LIST = [
|
||||
),
|
||||
pytest.param(
|
||||
"openbmb/MiniCPM3-4B",
|
||||
# fused_moe not supported on CPU
|
||||
marks=[pytest.mark.core_model, large_gpu_mark(min_gb=32)],
|
||||
),
|
||||
pytest.param(
|
||||
@ -93,11 +100,7 @@ AITER_MODEL_LIST = [
|
||||
pytest.param("bigcode/starcoder2-3b"), # starcoder2
|
||||
pytest.param(
|
||||
"TitanML/tiny-mixtral", # mixtral
|
||||
marks=[pytest.mark.core_model],
|
||||
),
|
||||
pytest.param(
|
||||
"allenai/OLMoE-1B-7B-0924-Instruct",
|
||||
marks=[pytest.mark.cpu_model],
|
||||
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
|
||||
),
|
||||
pytest.param("swiss-ai/Apertus-8B-Instruct-2509"), # apertus
|
||||
],
|
||||
|
||||
@ -27,13 +27,7 @@ def test_models(
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
if current_platform.is_rocm():
|
||||
# ROCm Triton FA does not currently support sliding window attention
|
||||
# switch to use ROCm CK FA backend
|
||||
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
|
||||
|
||||
with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
|
||||
vllm_outputs = vllm_model.classify(example_prompts)
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@
|
||||
import pytest
|
||||
|
||||
from vllm.config import PoolerConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...utils import check_embeddings_close
|
||||
|
||||
@ -23,8 +22,7 @@ from ...utils import check_embeddings_close
|
||||
),
|
||||
pytest.param(
|
||||
"intfloat/e5-mistral-7b-instruct",
|
||||
# CPU v1 doesn't support sliding window
|
||||
marks=[pytest.mark.core_model],
|
||||
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
|
||||
),
|
||||
pytest.param(
|
||||
"ssmits/Qwen2-7B-Instruct-embed-base", marks=[pytest.mark.cpu_model]
|
||||
@ -52,13 +50,7 @@ def test_models(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
if model == "BAAI/bge-multilingual-gemma2" and current_platform.is_rocm():
|
||||
# ROCm Triton FA does not currently support sliding window attention
|
||||
# switch to use ROCm CK FA backend
|
||||
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
|
||||
|
||||
vllm_extra_kwargs = {}
|
||||
if model == "ssmits/Qwen2-7B-Instruct-embed-base":
|
||||
vllm_extra_kwargs["pooler_config"] = PoolerConfig(
|
||||
|
||||
@ -2,18 +2,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.config.pooler import PoolerConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def test_idefics_multimodal(
|
||||
vllm_runner,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
if current_platform.is_rocm():
|
||||
# ROCm Triton FA does not currently support sliding window attention
|
||||
# switch to use ROCm CK FA backend
|
||||
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
@ -59,13 +52,7 @@ def update_config(config):
|
||||
|
||||
def test_gemma_multimodal(
|
||||
vllm_runner,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
if current_platform.is_rocm():
|
||||
# ROCm Triton FA does not currently support sliding window attention
|
||||
# switch to use ROCm CK FA backend
|
||||
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
|
||||
@ -76,7 +76,6 @@ def test_prm_models(
|
||||
math_step_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
check_transformers_version(
|
||||
"Qwen/Qwen2.5-Math-PRM-7B", max_transformers_version="4.53.2"
|
||||
@ -85,11 +84,6 @@ def test_prm_models(
|
||||
if current_platform.is_cpu():
|
||||
pytest.skip("CPU only supports V1")
|
||||
|
||||
if current_platform.is_rocm():
|
||||
# ROCm Triton FA does not currently support sliding window attention
|
||||
# switch to use ROCm CK FA backend
|
||||
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
|
||||
|
||||
with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
|
||||
vllm_outputs = vllm_model.reward(math_step_prompts)
|
||||
|
||||
|
||||
@ -5,16 +5,17 @@ image, embedding, and video support for different VLMs in vLLM.
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import PosixPath
|
||||
|
||||
import pytest
|
||||
from packaging.version import Version
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForImageTextToText,
|
||||
AutoModelForTextToWaveform,
|
||||
)
|
||||
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.func_utils import identity
|
||||
@ -38,13 +39,6 @@ from .vlm_utils.types import (
|
||||
VLMTestType,
|
||||
)
|
||||
|
||||
# This hack is needed for phi3v & paligemma models
|
||||
# ROCm Triton FA can run into shared memory issues with these models,
|
||||
# use other backends in the meantime
|
||||
# FIXME (mattwong, gshtrasb, hongxiayan)
|
||||
if current_platform.is_rocm():
|
||||
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
|
||||
|
||||
COMMON_BROADCAST_SETTINGS = {
|
||||
"test_type": VLMTestType.IMAGE,
|
||||
"dtype": "half",
|
||||
@ -859,6 +853,12 @@ VLM_TEST_SETTINGS = {
|
||||
limit_mm_per_prompt={"image": 4},
|
||||
)
|
||||
],
|
||||
marks=[
|
||||
pytest.mark.skipif(
|
||||
Version(TRANSFORMERS_VERSION) == Version("4.57.1"),
|
||||
reason="This model is broken in Transformers v4.57.1",
|
||||
)
|
||||
],
|
||||
),
|
||||
# regression test for https://github.com/vllm-project/vllm/issues/15122
|
||||
"qwen2_5_vl-windows-attention": VLMTestInfo(
|
||||
|
||||
@ -11,7 +11,6 @@ from huggingface_hub import snapshot_download
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ....conftest import (
|
||||
IMAGE_ASSETS,
|
||||
@ -46,12 +45,6 @@ models = [model_path]
|
||||
|
||||
target_dtype = "half"
|
||||
|
||||
# ROCm Triton FA can run into shared memory issues with these models,
|
||||
# use other backends in the meantime
|
||||
# FIXME (mattwong, gshtrasb, hongxiayan)
|
||||
if current_platform.is_rocm():
|
||||
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
|
||||
|
||||
|
||||
def run_test(
|
||||
hf_runner: type[HfRunner],
|
||||
|
||||
@ -14,7 +14,6 @@ from vllm.assets.image import ImageAsset
|
||||
from vllm.logprobs import SampleLogprobs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.image import convert_image_mode, rescale_image_size
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ....conftest import (
|
||||
IMAGE_ASSETS,
|
||||
@ -68,12 +67,6 @@ def vllm_to_hf_output(
|
||||
|
||||
target_dtype = "half"
|
||||
|
||||
# ROCm Triton FA can run into shared memory issues with these models,
|
||||
# use other backends in the meantime
|
||||
# FIXME (mattwong, gshtrasb, hongxiayan)
|
||||
if current_platform.is_rocm():
|
||||
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
|
||||
|
||||
|
||||
def run_test(
|
||||
hf_runner: type[HfRunner],
|
||||
|
||||
@ -61,10 +61,8 @@ def test_qwen2_5_vl_evs_functionality(
|
||||
model,
|
||||
runner="generate",
|
||||
max_model_len=4000,
|
||||
max_num_seqs=1,
|
||||
dtype=dtype,
|
||||
limit_mm_per_prompt={"video": 1},
|
||||
tensor_parallel_size=1,
|
||||
video_pruning_rate=video_pruning_rate,
|
||||
) as vllm_model:
|
||||
# Generate output - this should not crash
|
||||
|
||||
80
tests/models/quantization/test_gpt_oss_attn_quantization.py
Normal file
80
tests/models/quantization/test_gpt_oss_attn_quantization.py
Normal file
@ -0,0 +1,80 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Test attention quantization of gpt-oss model.
|
||||
The qkv_proj and o_proj in self_attention can be either quantized or excluded.
|
||||
|
||||
Run `pytest tests/models/quantization/test_gpt_oss_attn_quantization.py`.
|
||||
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
from dataclasses import dataclass
|
||||
|
||||
import huggingface_hub
|
||||
import lm_eval
|
||||
import pytest
|
||||
from packaging import version
|
||||
|
||||
MODEL_NAMES = ["amd/gpt-oss-20b-customized-attention-quantization"]
|
||||
|
||||
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse(
|
||||
importlib.metadata.version("amd-quark")
|
||||
) >= version.parse("0.8.99")
|
||||
|
||||
|
||||
def has_huggingface_access(repo):
|
||||
try:
|
||||
huggingface_hub.list_repo_refs(repo)
|
||||
return True
|
||||
except huggingface_hub.errors.RepositoryNotFoundError:
|
||||
return False
|
||||
|
||||
|
||||
HF_HUB_AMD_ORG_ACCESS = all(
|
||||
[has_huggingface_access(model_name) for model_name in MODEL_NAMES]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelCase:
|
||||
model_id: str
|
||||
tp: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluationConfig:
|
||||
model_name: str
|
||||
|
||||
def get_model_args(self) -> str:
|
||||
return (
|
||||
f"pretrained={self.model_name},"
|
||||
"tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=False"
|
||||
)
|
||||
|
||||
|
||||
EXPECTED_ACCURACIES = {"arc_challenge": 0.20}
|
||||
|
||||
|
||||
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
|
||||
@pytest.mark.skipif(
|
||||
not HF_HUB_AMD_ORG_ACCESS,
|
||||
reason="Read access to huggingface.co/amd is required for this test.",
|
||||
)
|
||||
@pytest.mark.parametrize("model_name", MODEL_NAMES)
|
||||
@pytest.mark.parametrize("task_name, expected_accuracy", EXPECTED_ACCURACIES.items())
|
||||
def test_gpt_oss_attention_quantization(
|
||||
model_name: str, task_name: str, expected_accuracy: float
|
||||
):
|
||||
measured_accuracy = lm_eval.simple_evaluate(
|
||||
model="vllm",
|
||||
model_args=EvaluationConfig(model_name).get_model_args(),
|
||||
tasks=task_name,
|
||||
batch_size="auto",
|
||||
)["results"][task_name]["acc,none"]
|
||||
|
||||
rtol = 0.05
|
||||
assert (
|
||||
measured_accuracy - rtol < expected_accuracy
|
||||
and measured_accuracy + rtol > expected_accuracy
|
||||
), f"Expected: {expected_accuracy} | Measured: {measured_accuracy}"
|
||||
@ -243,7 +243,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
|
||||
"FlexOlmoForCausalLM": _HfExamplesInfo("allenai/Flex-reddit-2x7B-1T"),
|
||||
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
|
||||
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
|
||||
"Gemma2ForCausalLM": _HfExamplesInfo(
|
||||
"google/gemma-2-9b", extras={"tiny": "google/gemma-2-2b-it"}
|
||||
),
|
||||
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
|
||||
"Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it"),
|
||||
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
|
||||
|
||||
@ -93,6 +93,17 @@ def can_initialize(
|
||||
"pickle error when loading `transformers.models.auto.CONFIG_MAPPING`"
|
||||
)
|
||||
|
||||
if model_arch == "DeepseekV32ForCausalLM":
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
if capability and capability.major < 9:
|
||||
pytest.skip(
|
||||
f"DeepseekV32 requires Hopper (9.0+) or Blackwell (10.0+) "
|
||||
f"for FLASHMLA_SPARSE backend. Current device has compute "
|
||||
f"capability {capability.major}.{capability.minor}"
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1),
|
||||
monkeypatch.context() as m,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user