diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 12f730738b8a5..38c400ba1faf5 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -132,7 +132,7 @@ steps: queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --build-arg VLLM_CPU_AMXBF16=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ." - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest" - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)" env: diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index 58fd435691f4a..0e5b21ddf25b3 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -78,17 +78,13 @@ HF_MOUNT="/root/.cache/huggingface" commands=$@ echo "Commands:$commands" -if [[ $commands == *"pytest -v -s basic_correctness/test_basic_correctness.py"* ]]; then - commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s basic_correctness/test_basic_correctness.py"} -fi +commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"pytest -v -s basic_correctness/test_basic_correctness.py"} if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"} fi -if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then - commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"} -fi +commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"pytest -v -s compile/test_basic_correctness.py"} if [[ $commands == *"pytest -v -s lora"* ]]; then commands=${commands//"pytest -v -s lora"/"VLLM_ROCM_CUSTOM_PAGED_ATTN=0 pytest -v -s lora"} diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 7927aef19e4eb..7479c43977d78 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -49,6 +49,7 @@ function cpu_tests() { # Run kernel tests docker exec cpu-test-"$NUMA_NODE" bash -c " set -e + pytest -x -v -s tests/kernels/attention/test_cpu_attn.py pytest -x -v -s tests/kernels/test_onednn.py" # Run basic model test @@ -76,7 +77,7 @@ function cpu_tests() { # Run AWQ test # docker exec cpu-test-"$NUMA_NODE" bash -c " # set -e - # VLLM_USE_V1=0 pytest -x -s -v \ + # pytest -x -s -v \ # tests/quantization/test_ipex_quant.py" # Run multi-lora tests @@ -116,4 +117,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 40 mins. export -f cpu_tests -timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +timeout 2.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index bb5ef5d624630..5fd048c2ad0c6 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -348,6 +348,7 @@ steps: - pytest -v -s -m 'not cpu_test' v1/metrics - pytest -v -s v1/test_oracle.py - pytest -v -s v1/test_request.py + - pytest -v -s v1/test_outputs.py # Integration test for streaming correctness (requires special branch). - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index a0d2076199b14..be1b79ddc4324 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -25,6 +25,7 @@ # and $$BUILDKITE_PARALLEL_JOB_COUNT environment variables. # working_dir(str): specify the place where the command should execute, default to /vllm-workspace/tests # source_file_dependencies(list): the list of prefixes to opt-in the test for, if empty, the test will always run. +# autorun_on_main (bool): default to false, if true, the test will run automatically when commit is pushed to main branch. # When adding a test # - If the test belongs to an existing group, add it there @@ -56,7 +57,7 @@ steps: - pytest -v -s -m 'not cpu_test' multimodal - pytest -v -s utils_ -- label: Async Engine, Inputs, Utils, Worker Test (CPU) # 4 mins +- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 4 mins timeout_in_minutes: 10 source_file_dependencies: - vllm/ @@ -65,6 +66,7 @@ steps: - tests/multimodal - tests/standalone_tests/lazy_imports.py - tests/transformers_utils + - tests/config no_gpu: true commands: - python3 standalone_tests/lazy_imports.py @@ -72,6 +74,7 @@ steps: - pytest -v -s test_outputs.py - pytest -v -s -m 'cpu_test' multimodal - pytest -v -s transformers_utils + - pytest -v -s config - label: Python-only Installation Test # 10min timeout_in_minutes: 20 @@ -329,6 +332,7 @@ steps: - pytest -v -s -m 'not cpu_test' v1/metrics - pytest -v -s v1/test_oracle.py - pytest -v -s v1/test_request.py + - pytest -v -s v1/test_outputs.py # Integration test for streaming correctness (requires special branch). - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine @@ -441,6 +445,7 @@ steps: - vllm/ - tests/compile commands: + - pytest -v -s compile/test_config.py - pytest -v -s compile/test_pass_manager.py - pytest -v -s compile/test_fusion.py - pytest -v -s compile/test_fusion_attn.py @@ -450,6 +455,7 @@ steps: - pytest -v -s compile/test_decorator.py - pytest -v -s compile/test_noop_elimination.py - pytest -v -s compile/test_aot_compile.py + - pytest -v -s compile/test_qk_norm_rope_fusion.py - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 @@ -604,6 +610,7 @@ steps: source_file_dependencies: - csrc/ - vllm/model_executor/layers/quantization + autorun_on_main: true commands: - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 @@ -867,12 +874,12 @@ steps: optional: true commands: - pip install --upgrade git+https://github.com/huggingface/transformers - - pytest -v -s tests/models/test_initialization.py + - pytest -v -s tests/models/test_initialization.py -k 'not (Gemma3 or ModernBert or Qwen2_5_VL or Qwen2_5vl or Qwen2VL or TransformersMultiModalEmbeddingModel or TransformersMultiModalForSequenceClassification or Ultravox or Phi4Multimodal or LlavaNextVideo or MiniCPMO or Lfm2Moe or PaliGemma or RobertaForSequenceClassification or Ovis2_5 or Fuyu or DeepseekOCR or KimiVL)' - pytest -v -s tests/models/test_transformers.py - - pytest -v -s tests/models/multimodal/processing/ - - pytest -v -s tests/models/multimodal/test_mapping.py + # - pytest -v -s tests/models/multimodal/processing/ + - pytest -v -s tests/models/multimodal/test_mapping.py -k 'not (Gemma3 or Qwen2VL or Qwen2_5_VL)' - python3 examples/offline_inference/basic/chat.py - - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl + # - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl # Whisper needs spawn method to avoid deadlock - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper @@ -890,11 +897,16 @@ steps: - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py + - vllm/v1/attention/backends/mla/cutlass_mla.py + - vllm/v1/attention/backends/mla/flashinfer_mla.py + - vllm/platforms/cuda.py + - vllm/attention/selector.py commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py # Attention # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353 + - pytest -v -s tests/kernels/attention/test_attention_selector.py - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2' - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py @@ -932,7 +944,7 @@ steps: # this runner has 2 GPUs available even though num_gpus=2 is not set - pytest -v -s tests/compile/test_fusion_all_reduce.py # Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time - # Wrap with quotes to escape yaml + # Wrap with quotes to escape yaml - "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and Llama-3.1 and -quant_fp8 and -rms_norm'" - label: Blackwell Fusion E2E Tests # 30 min diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 23def076cf880..f26c782bccf2c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -61,6 +61,16 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /vllm/model_executor/models/transformers @hmellor /tests/models/test_transformers.py @hmellor +# Observability +/vllm/config/observability.py @markmc +/vllm/v1/metrics @markmc +/tests/v1/metrics @markmc +/vllm/tracing.py @markmc +/tests/v1/tracing/test_tracing.py @markmc +/vllm/config/kv_events.py @markmc +/vllm/distributed/kv_events.py @markmc +/tests/distributed/test_events.py @markmc + # Docs /docs/mkdocs @hmellor /docs/**/*.yml @hmellor diff --git a/.github/mergify.yml b/.github/mergify.yml index 18d4a2e83144b..997a40e18e588 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -151,6 +151,23 @@ pull_request_rules: add: - gpt-oss +- name: label-nvidia + description: Automatically apply nvidia label + conditions: + - label != stale + - or: + - files~=cuda + - files~=cutlass + - files~=flashinfer + - files~=trtllm + - title~=(?i)NVIDIA + - title~=(?i)CUDA + - title~=(?i)CUTLASS + actions: + label: + add: + - nvidia + - name: label-rocm description: Automatically apply rocm label conditions: diff --git a/CMakeLists.txt b/CMakeLists.txt index 0e9fa63b178ea..dcc44be87e557 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,6 +39,13 @@ set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13") # Supported AMD GPU architectures. set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151") +# ROCm installation prefix. Default to /opt/rocm but allow override via +# -DROCM_PATH=/your/rocm/path when invoking cmake. +if(NOT DEFINED ROCM_PATH) + set(ROCM_PATH "/opt/rocm" CACHE PATH "ROCm installation prefix") +else() + set(ROCM_PATH ${ROCM_PATH} CACHE PATH "ROCm installation prefix" FORCE) +endif() # # Supported/expected torch versions for CUDA/ROCm. # @@ -237,10 +244,27 @@ set_gencode_flags_for_srcs( SRCS "${VLLM_CUMEM_EXT_SRC}" CUDA_ARCHS "${CUDA_ARCHS}") -if(VLLM_GPU_LANG STREQUAL "CUDA") +if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling cumem allocator extension.") - # link against cuda driver library - list(APPEND CUMEM_LIBS CUDA::cuda_driver) + if(VLLM_GPU_LANG STREQUAL "CUDA") + # link against cuda driver library + list(APPEND CUMEM_LIBS CUDA::cuda_driver) + else() + # link against rocm driver library. Prefer an absolute path to + # libamdhip64.so inside ${ROCM_PATH}/lib if available, otherwise fall + # back to linking by name "amdhip64". + find_library(AMDHIP64_LIB + NAMES amdhip64 libamdhip64.so + PATHS ${ROCM_PATH}/lib + NO_DEFAULT_PATH) + if(AMDHIP64_LIB) + message(STATUS "Found libamdhip64 at ${AMDHIP64_LIB}") + list(APPEND CUMEM_LIBS ${AMDHIP64_LIB}) + else() + message(WARNING "libamdhip64 not found in ${ROCM_PATH}/lib; falling back to linking 'amdhip64' by name") + list(APPEND CUMEM_LIBS amdhip64) + endif() + endif() define_extension_target( cumem_allocator DESTINATION vllm @@ -265,6 +289,7 @@ set(VLLM_EXT_SRC "csrc/pos_encoding_kernels.cu" "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" + "csrc/fused_qknorm_rope_kernel.cu" "csrc/layernorm_quant_kernels.cu" "csrc/sampler.cu" "csrc/cuda_view.cu" @@ -330,7 +355,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Keep building Marlin for 9.0 as there are some group sizes and shapes that # are not supported by Machete yet. # 9.0 for latest bf16 atomicAdd PTX - cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.7;9.0+PTX" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") if (MARLIN_ARCHS) # @@ -914,7 +939,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") # 9.0 for latest bf16 atomicAdd PTX - cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.7;9.0+PTX" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") if (MARLIN_MOE_ARCHS) # diff --git a/README.md b/README.md index b5e230e4b9b07..033e1035d8916 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundatio *Latest News* 🔥 +- [2025/11] We hosted [the first vLLM Europe Meetup in Zurich](https://luma.com/0gls27kb) focused on quantization, distributed inference, and reinforcement learning at scale with speakers from Mistral, IBM, and Red Hat. Please find the meetup slides [here](https://docs.google.com/presentation/d/1UC9PTLCHYXQpOmJDSFg6Sljra3iVXzc09DeEI7dnxMc/edit?usp=sharing) and recording [here](https://www.youtube.com/watch?v=6m6ZE6yVEDI) - [2025/11] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/xSrYXjNgr1HbCP4ExYNG1w) focusing on distributed inference and diverse accelerator support with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1nQJ8ZkLSjKxvu36sSHaceVXtttbLvvu-?usp=drive_link). - [2025/10] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/__xb4OyOsImz-9eAVrdlcg) focused on hands-on vLLM inference optimization! Please find the meetup slides [here](https://drive.google.com/drive/folders/1KqwjsFJLfEsC8wlDugnrR61zsWHt94Q6). - [2025/09] We hosted [vLLM Toronto Meetup](https://luma.com/e80e0ymm) focused on tackling inference at scale and speculative decoding with speakers from NVIDIA and Red Hat! Please find the meetup slides [here](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing). diff --git a/benchmarks/kernels/bench_block_fp8_gemm.py b/benchmarks/kernels/bench_block_fp8_gemm.py index f1e504499eaf6..11e3ac7f0c1fa 100644 --- a/benchmarks/kernels/bench_block_fp8_gemm.py +++ b/benchmarks/kernels/bench_block_fp8_gemm.py @@ -1,10 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os + +# Disable DeepGEMM for this benchmark to use CUTLASS +os.environ["VLLM_USE_DEEP_GEMM"] = "0" + import torch from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - apply_w8a8_block_fp8_linear, + W8A8BlockFp8LinearOp, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED, @@ -39,13 +47,14 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass): fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min - # Create random FP8 tensors + # Create random input tensor (bfloat16, will be quantized by W8A8BlockFp8LinearOp) A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max + # Create quantized weight tensor B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - # Create scales + # Create weight scales block_n, block_k = block_size[0], block_size[1] n_tiles = (N + block_n - 1) // block_n k_tiles = (K + block_k - 1) // block_k @@ -55,19 +64,25 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass): * factor_for_scale ) - # SM90 CUTLASS requires row-major format for scales - if use_cutlass and current_platform.is_device_capability(90): - Bs = Bs.T.contiguous() + # Create W8A8BlockFp8LinearOp instance + weight_group_shape = GroupShape(block_n, block_k) + act_quant_group_shape = GroupShape(1, block_k) # Per-token, per-group quantization + + linear_op = W8A8BlockFp8LinearOp( + weight_group_shape=weight_group_shape, + act_quant_group_shape=act_quant_group_shape, + cutlass_block_fp8_supported=use_cutlass, + use_aiter_and_is_supported=False, + ) def run(): - if use_cutlass: - return apply_w8a8_block_fp8_linear( - A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True - ) - else: - return apply_w8a8_block_fp8_linear( - A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False - ) + return linear_op.apply( + input=A_ref, + weight=B, + weight_scale=Bs, + input_scale=None, + bias=None, + ) return run diff --git a/benchmarks/multi_turn/bench_dataset.py b/benchmarks/multi_turn/bench_dataset.py index 2674899d1cc56..8cb8a2f386a97 100644 --- a/benchmarks/multi_turn/bench_dataset.py +++ b/benchmarks/multi_turn/bench_dataset.py @@ -11,6 +11,7 @@ from bench_utils import ( Color, logger, ) +from tqdm import tqdm from transformers import AutoTokenizer # type: ignore # Conversation ID is a string (e.g: "UzTK34D") @@ -417,6 +418,10 @@ def generate_conversations( data = file.read() tokens_in_file = tokenizer.encode(data, add_special_tokens=False) list_of_tokens.extend(tokens_in_file) + logger.info( + f"Loaded {len(tokens_in_file)} tokens from file {filename}, " + f"total tokens so far: {len(list_of_tokens)}" + ) conversations: ConversationsMap = {} conv_id = 0 @@ -449,18 +454,25 @@ def generate_conversations( ) base_offset += common_prefix_tokens - for conv_id in range(args.num_conversations): + for conv_id in tqdm( + range(args.num_conversations), + total=args.num_conversations, + desc="Generating conversations", + unit="conv", + ): # Generate a single conversation messages: MessagesList = [] nturns = turn_count[conv_id] # User prompt token count per turn (with lower limit) - input_token_count: np.ndarray = args.input_num_tokens.sample(nturns) + input_token_count: np.ndarray = args.input_num_tokens.sample(nturns).astype(int) input_token_count = np.maximum(input_token_count, base_prompt_token_count) # Assistant answer token count per turn (with lower limit) - output_token_count: np.ndarray = args.output_num_tokens.sample(nturns) + output_token_count: np.ndarray = args.output_num_tokens.sample(nturns).astype( + int + ) output_token_count = np.maximum(output_token_count, 1) user_turn = True diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py index 5d2ac66e5ab94..ae9e9753441aa 100644 --- a/benchmarks/multi_turn/benchmark_serving_multi_turn.py +++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py @@ -55,6 +55,7 @@ class ClientArgs(NamedTuple): verify_output: bool conversation_sampling: ConversationSampling request_rate: float + max_retries: int class RequestArgs(NamedTuple): @@ -63,6 +64,7 @@ class RequestArgs(NamedTuple): stream: bool limit_min_tokens: int # Use negative value for no limit limit_max_tokens: int # Use negative value for no limit + timeout_sec: int class BenchmarkArgs(NamedTuple): @@ -214,6 +216,7 @@ async def send_request( stream: bool = True, min_tokens: int | None = None, max_tokens: int | None = None, + timeout_sec: int = 120, ) -> ServerResponse: payload = { "model": model, @@ -235,10 +238,16 @@ async def send_request( headers = {"Content-Type": "application/json"} # Calculate the timeout for the request - timeout_sec = 120 if max_tokens is not None: # Assume TPOT of 200ms and use max_tokens to determine timeout - timeout_sec = max(timeout_sec, int(max_tokens * 0.2)) + token_based_timeout = int(max_tokens * 0.2) + if token_based_timeout > timeout_sec: + timeout_sec = token_based_timeout + logger.info( + "Using timeout of %ds based on max_tokens %d", + timeout_sec, + max_tokens, + ) timeout = aiohttp.ClientTimeout(total=timeout_sec) valid_response = True @@ -409,6 +418,7 @@ async def send_turn( req_args.stream, min_tokens, max_tokens, + req_args.timeout_sec, ) if response.valid is False: @@ -518,6 +528,25 @@ async def poisson_sleep(request_rate: float, verbose: bool = False) -> None: await asyncio.sleep(interval) +async def exponential_backoff_sleep( + attempt_cnt: int, + base_rate: float = 1.0, + backoff_factor: float = 2.0, + jitter_fraction: float = 0.10, + verbose: bool = False, +) -> None: + # Sleep with exponential backoff and jitter after a failed request. + backoff_delay = base_rate * (backoff_factor**attempt_cnt) + jittered_delay = backoff_delay * ( + 1 + np.random.uniform(-jitter_fraction, jitter_fraction) + ) + + if verbose: + logger.info(f"Backoff for {jittered_delay:.3f} seconds...") + + await asyncio.sleep(jittered_delay) + + async def client_main( args: ClientArgs, req_args: RequestArgs, @@ -646,49 +675,62 @@ async def client_main( ) time_of_last_turn[conv_id] = curr_time_sec - success = True - try: - result = await send_turn( - session, - client_id, - conv_id, - messages, - current_turn, - tokenizer, - req_args, - args.print_content, - args.verify_output, - ) - if result is not None: - result_queue.put(result) - else: - # None means that the request failed, - # and should not be added to the statistics. - success = False - num_failures += 1 - - logger.warning( - f"{Color.YELLOW}Client {client_id} - Request rejected during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 + success = False + for attempt_cnt in range(args.max_retries + 1): + try: + exception = False + result = await send_turn( + session, + client_id, + conv_id, + messages, + current_turn, + tokenizer, + req_args, + args.print_content, + args.verify_output, + ) + if result is not None: + result_queue.put(result) + success = True + break + else: + logger.warning( + f"{Color.YELLOW}Client {client_id} - Request rejected during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 + ) + except asyncio.exceptions.TimeoutError: + exception = True + logger.error( + "%sClient %d - Timeout during conversation ID %s (turn: %d). " + "Base timeout is %ss (set with --request-timeout-sec), but the " + "effective timeout may be longer based on max_tokens. If this " + "is unexpected, consider increasing the timeout or checking " + "model performance.%s", + Color.RED, + client_id, + conv_id, + current_turn, + req_args.timeout_sec, + Color.RESET, + ) + except Exception: + exception = True + logger.exception( + f"{Color.RED}Client {client_id} - Exception during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 ) - # Remove the conversation (should not be used again) - active_convs.pop(conv_id) + # Sleep before retry if not last attempt + if not success and attempt_cnt < args.max_retries: + await exponential_backoff_sleep(attempt_cnt, verbose=args.verbose) - except asyncio.exceptions.TimeoutError: + if not success: num_failures += 1 - logger.exception( - f"{Color.RED}Client {client_id} - Timeout during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 - ) - break # Exit gracefully instead of raising an error + # Remove the conversation (should not be used again) + active_convs.pop(conv_id) + if exception: + break # Exit gracefully instead of raising an error - except Exception: - num_failures += 1 - logger.exception( - f"{Color.RED}Client {client_id} - Exception during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 - ) - break # Exit gracefully instead of raising an error - - if success: + else: num_successes += 1 # Update the turns counter to include the LLM response @@ -803,6 +845,7 @@ def get_client_config( verify_output=args.verify_output, conversation_sampling=args.conversation_sampling, request_rate=args.request_rate, + max_retries=args.max_retries, ) if args.limit_min_tokens > 0 or args.limit_max_tokens > 0: @@ -815,6 +858,9 @@ def get_client_config( "Invalid min/max tokens limits (min should not be larger than max)" ) + if args.request_timeout_sec <= 0: + raise ValueError("Request timeout must be a positive number") + # Arguments for API requests chat_url = f"{args.url}/v1/chat/completions" model_name = args.served_model_name if args.served_model_name else args.model @@ -825,6 +871,7 @@ def get_client_config( stream=not args.no_stream, limit_min_tokens=args.limit_min_tokens, limit_max_tokens=args.limit_max_tokens, + timeout_sec=args.request_timeout_sec, ) return client_args, req_args @@ -968,7 +1015,7 @@ async def main_mp( f"(is alive: {client.is_alive()}){Color.RESET}" ) - client.join(timeout=120) + client.join(timeout=req_args.timeout_sec + 1) if client.is_alive(): logger.warning( @@ -1334,6 +1381,16 @@ async def main() -> None: help="Expected request rate (Poisson process) per client in requests/sec." "Set to 0 for no delay between requests.", ) + parser.add_argument( + "--max-retries", + type=int, + default=int(os.environ.get("MULTITURN_BENCH_MAX_RETRIES", "0")), + help="Maximum number of retry attempts for timed-out requests. " + "Default is 0 (no retries). " + "Set to higher values to retry failed requests and maintain " + "fair workload distribution. " + "Can also be set via MULTITURN_BENCH_MAX_RETRIES environment variable.", + ) parser.add_argument( "--conversation-sampling", type=ConversationSampling, @@ -1351,6 +1408,13 @@ async def main() -> None: action="store_true", help="Verify the LLM output (compare to the answers in the input JSON file)", ) + parser.add_argument( + "--request-timeout-sec", + type=int, + default=120, + help="Timeout in seconds for each API request (default: 120). " + "Automatically increased if max tokens imply longer decoding.", + ) parser.add_argument( "--no-stream", diff --git a/benchmarks/multi_turn/requirements.txt b/benchmarks/multi_turn/requirements.txt index f0e1935914a14..bae656a5c5c4b 100644 --- a/benchmarks/multi_turn/requirements.txt +++ b/benchmarks/multi_turn/requirements.txt @@ -2,4 +2,5 @@ numpy>=1.24 pandas>=2.0.0 aiohttp>=3.10 transformers>=4.46 -xlsxwriter>=3.2.1 \ No newline at end of file +xlsxwriter>=3.2.1 +tqdm>=4.66 diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index dbda19fbcbf20..bb0179c79c108 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -15,6 +15,7 @@ endif() # set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16}) set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI}) +set(ENABLE_AMXBF16 $ENV{VLLM_CPU_AMXBF16}) include_directories("${CMAKE_SOURCE_DIR}/csrc") @@ -140,6 +141,22 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) set(ENABLE_AVX512VNNI OFF) message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.") endif() + + find_isa(${CPUINFO} "amx_bf16" AMXBF16_FOUND) + if (AMXBF16_FOUND OR ENABLE_AMXBF16) + if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) + list(APPEND CXX_COMPILE_FLAGS "-mamx-bf16" "-mamx-tile") + set(ENABLE_AMXBF16 ON) + add_compile_definitions(-DCPU_CAPABILITY_AMXBF16) + else() + set(ENABLE_AMXBF16 OFF) + message(WARNING "Disable AMX_BF16 ISA support, requires gcc/g++ >= 12.3") + endif() + else() + set(ENABLE_AMXBF16 OFF) + message(WARNING "Disable AMX_BF16 ISA support, no amx_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AMXBF16=1.") + endif() elseif (AVX2_FOUND) list(APPEND CXX_COMPILE_FLAGS "-mavx2") @@ -193,7 +210,30 @@ endif() if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) # Fetch and build Arm Compute Library (ACL) as oneDNN's backend for AArch64 # TODO [fadara01]: remove this once ACL can be fetched and built automatically as a dependency of oneDNN + set(ONEDNN_AARCH64_USE_ACL OFF CACHE BOOL "") if(ASIMD_FOUND) + # Set number of parallel build processes + include(ProcessorCount) + ProcessorCount(NPROC) + if(NOT NPROC) + set(NPROC 4) + endif() + # locate PyTorch's libgomp (e.g. site-packages/torch.libs/libgomp-947d5fa1.so.1.0.0) + # and create a local shim dir with it + vllm_prepare_torch_gomp_shim(VLLM_TORCH_GOMP_SHIM_DIR) + + find_library(OPEN_MP + NAMES gomp + PATHS ${VLLM_TORCH_GOMP_SHIM_DIR} + NO_DEFAULT_PATH + REQUIRED + ) + # Set LD_LIBRARY_PATH to include the shim dir at build time to use the same libgomp as PyTorch + if (OPEN_MP) + set(ENV{LD_LIBRARY_PATH} "${VLLM_TORCH_GOMP_SHIM_DIR}:$ENV{LD_LIBRARY_PATH}") + endif() + + # Fetch and populate ACL if(DEFINED ENV{ACL_ROOT_DIR} AND IS_DIRECTORY "$ENV{ACL_ROOT_DIR}") message(STATUS "Using ACL from specified source directory: $ENV{ACL_ROOT_DIR}") else() @@ -207,38 +247,53 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON GIT_PROGRESS TRUE ) set(ENV{ACL_ROOT_DIR} "${arm_compute_SOURCE_DIR}") + set(ACL_LIB_DIR "$ENV{ACL_ROOT_DIR}/build") endif() - # Build ACL with scons - include(ProcessorCount) - ProcessorCount(_NPROC) - set(_scons_cmd - scons -j${_NPROC} - Werror=0 debug=0 neon=1 examples=0 embed_kernels=0 os=linux - arch=armv8.2-a build=native benchmark_examples=0 fixed_format_kernels=1 - multi_isa=1 openmp=1 cppthreads=0 + # Build ACL with CMake + set(ARM_COMPUTE_BUILD_SHARED_LIB "OFF") + set(CMAKE_BUILD_TYPE "Release") + set(ARM_COMPUTE_ARCH "armv8.2-a") + set(ARM_COMPUTE_ENABLE_ASSERTS "OFF") + set(ARM_COMPUTE_ENABLE_CPPTHREADS "OFF") + set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") + set(ARM_COMPUTE_ENABLE_OPENMP "ON") + set(ARM_COMPUTE_ENABLE_WERROR "OFF") + set(ARM_COMPUTE_BUILD_EXAMPLES "OFF") + set(ARM_COMPUTE_BUILD_TESTING "OFF") + + set(_cmake_config_cmd + ${CMAKE_COMMAND} -G Ninja -B build + -DARM_COMPUTE_BUILD_SHARED_LIB=OFF + -DCMAKE_BUILD_TYPE=Release + -DARM_COMPUTE_ARCH=armv8.2-a + -DARM_COMPUTE_ENABLE_ASSERTS=OFF + -DARM_COMPUTE_ENABLE_CPPTHREADS=OFF + -DARM_COMPUTE_ENABLE_OPENMP=ON + -DARM_COMPUTE_ENABLE_WERROR=OFF + -DARM_COMPUTE_BUILD_EXAMPLES=OFF + -DARM_COMPUTE_BUILD_TESTING=OFF) + set(_cmake_build_cmd + ${CMAKE_COMMAND} --build build -- -j${NPROC} ) - # locate PyTorch's libgomp (e.g. site-packages/torch.libs/libgomp-947d5fa1.so.1.0.0) - # and create a local shim dir with it - include("${CMAKE_CURRENT_LIST_DIR}/utils.cmake") - vllm_prepare_torch_gomp_shim(VLLM_TORCH_GOMP_SHIM_DIR) - - if(NOT VLLM_TORCH_GOMP_SHIM_DIR STREQUAL "") - list(APPEND _scons_cmd extra_link_flags=-L${VLLM_TORCH_GOMP_SHIM_DIR}) - endif() - execute_process( - COMMAND ${_scons_cmd} + COMMAND ${_cmake_config_cmd} + WORKING_DIRECTORY "$ENV{ACL_ROOT_DIR}" + ) + execute_process( + COMMAND ${_cmake_build_cmd} WORKING_DIRECTORY "$ENV{ACL_ROOT_DIR}" RESULT_VARIABLE _acl_rc ) + if(NOT _acl_rc EQUAL 0) message(FATAL_ERROR "ACL SCons build failed (exit ${_acl_rc}).") endif() + message(STATUS "Arm Compute Library (ACL) built successfully.") - set(ONEDNN_AARCH64_USE_ACL "ON") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/") + # VLLM/oneDNN settings for ACL + set(ONEDNN_AARCH64_USE_ACL ON CACHE BOOL "" FORCE) add_compile_definitions(VLLM_USE_ACL) endif() @@ -275,7 +330,10 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON set(ONEDNN_VERBOSE "OFF") set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE}) + set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size FetchContent_MakeAvailable(oneDNN) + set(CMAKE_BUILD_TYPE ${VLLM_BUILD_TYPE}) add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp") target_include_directories( dnnl_ext @@ -305,14 +363,14 @@ endif() # set(VLLM_EXT_SRC "csrc/cpu/activation.cpp" - "csrc/cpu/attention.cpp" - "csrc/cpu/cache.cpp" "csrc/cpu/utils.cpp" "csrc/cpu/layernorm.cpp" "csrc/cpu/mla_decode.cpp" "csrc/cpu/pos_encoding.cpp" - "csrc/cpu/torch_bindings.cpp" - "csrc/moe/dynamic_4bit_int_moe_cpu.cpp") + "csrc/moe/dynamic_4bit_int_moe_cpu.cpp" + "csrc/cpu/cpu_attn.cpp" + "csrc/cpu/scratchpad_manager.cpp" + "csrc/cpu/torch_bindings.cpp") if (AVX512_FOUND AND NOT AVX512_DISABLED) set(VLLM_EXT_SRC diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp deleted file mode 100644 index 82862fea7f2be..0000000000000 --- a/csrc/cpu/attention.cpp +++ /dev/null @@ -1,798 +0,0 @@ -#include "cpu_types.hpp" - -namespace { - -template -struct KernelVecType { - using q_load_vec_type = void; - using q_vec_type = void; - using k_load_vec_type = void; - using k_vec_type = void; - using qk_acc_vec_type = void; - using v_load_vec_type = void; -}; - -template <> -struct KernelVecType { - using q_load_vec_type = vec_op::FP32Vec4; - using q_vec_type = vec_op::FP32Vec16; - using k_load_vec_type = vec_op::FP32Vec16; - using k_vec_type = vec_op::FP32Vec16; - using qk_acc_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::FP32Vec16; -}; - -template <> -struct KernelVecType { -#if defined(__powerpc64__) || defined(__s390x__) - // Power and s390x architecture-specific vector types - using q_load_vec_type = vec_op::FP32Vec8; - using k_load_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::FP32Vec16; -#else - // Fallback for other architectures, including x86 - using q_load_vec_type = vec_op::FP16Vec8; - using k_load_vec_type = vec_op::FP16Vec16; - using v_load_vec_type = vec_op::FP16Vec16; -#endif - using q_vec_type = vec_op::FP32Vec16; - using k_vec_type = vec_op::FP32Vec16; - using qk_acc_vec_type = vec_op::FP32Vec16; -}; - -#ifdef __AVX512BF16__ -template <> -struct KernelVecType { - using q_load_vec_type = vec_op::BF16Vec8; - using q_vec_type = vec_op::BF16Vec32; - using k_load_vec_type = vec_op::BF16Vec32; - using k_vec_type = vec_op::BF16Vec32; - using qk_acc_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::BF16Vec16; -}; -#else - #ifdef __aarch64__ - #ifndef ARM_BF16_SUPPORT - // pass - #else -template <> -struct KernelVecType { - using q_load_vec_type = vec_op::BF16Vec8; - using q_vec_type = vec_op::FP32Vec16; - using k_load_vec_type = vec_op::BF16Vec16; - using k_vec_type = vec_op::FP32Vec16; - using qk_acc_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::BF16Vec16; -}; - #endif - #else -template <> -struct KernelVecType { - using q_load_vec_type = vec_op::BF16Vec8; - using q_vec_type = vec_op::FP32Vec16; - using k_load_vec_type = vec_op::BF16Vec16; - using k_vec_type = vec_op::FP32Vec16; - using qk_acc_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::BF16Vec16; -}; - #endif -#endif - -template -FORCE_INLINE std::pair reduceSoftmax(T* data, const int size, - const int capacity) { - T max = data[0]; - for (int i = 1; i < size; ++i) { - max = max >= data[i] ? max : data[i]; - } - - T sum = 0; - for (int i = 0; i < size; ++i) { - data[i] = std::exp(data[i] - max); - sum += data[i]; - } - - int i = 0; - for (; i < size; ++i) { - data[i] /= sum; - } - - for (; i < capacity; ++i) { - data[i] = 0; - } - - return {max, sum}; -} - -template -FORCE_INLINE std::pair reduceSoftmaxAlibi(T* data, const int size, - const int capacity, - const float alibi_slope, - const int start_index, - const int seq_len) { - data[0] += alibi_slope * (start_index - seq_len + 1); - T max = data[0]; - for (int i = 1; i < size; ++i) { - T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1); - data[i] = qk; - max = max >= qk ? max : qk; - } - - T sum = 0; - for (int i = 0; i < size; ++i) { - data[i] = std::exp(data[i] - max); - sum += data[i]; - } - - int i = 0; - for (; i < size; ++i) { - data[i] /= sum; - } - - for (; i < capacity; ++i) { - data[i] = 0; - } - - return {max, sum}; -} - -template -FORCE_INLINE void reducePartitionSoftmax(const T* max_data, T* sum_data, - const int size) { - T max = max_data[0]; - for (int i = 1; i < size; ++i) { - max = max >= max_data[i] ? max : max_data[i]; - } - - T rescaled_sum = 0; - for (int i = 0; i < size; ++i) { - T rescale_factor = std::exp(max_data[i] - max); - rescaled_sum += rescale_factor * sum_data[i]; - sum_data[i] *= rescale_factor; - } - for (int i = 0; i < size; ++i) { - sum_data[i] /= rescaled_sum + 1e-8; - } -} - -template -struct reduceQKBlockKernel { - using q_load_vec_type = typename KernelVecType::q_load_vec_type; - using q_vec_type = typename KernelVecType::q_vec_type; - using k_load_vec_type = typename KernelVecType::k_load_vec_type; - using k_vec_type = typename KernelVecType::k_vec_type; - using qk_acc_vec_type = typename KernelVecType::qk_acc_vec_type; - - constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x; - constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP; - constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4; - - static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4); - static_assert(k_load_vec_type::get_elem_num() % x == 0); - static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); - - FORCE_INLINE static void call(const scalar_t* __restrict__ q, - const scalar_t* __restrict__ k_block, - float* __restrict__ logits, float scale, - const int token_num) { - const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; - - qk_acc_vec_type group_accums[MAX_GROUP_NUM]; - if (token_num == BLOCK_SIZE) { - for (int q_offset = 0; q_offset < HEAD_SIZE; - q_offset += x, k_block += x * BLOCK_SIZE) { - q_load_vec_type q_load_group_vec(q + q_offset); - q_vec_type q_group_vec(q_load_group_vec); - - vec_op::unroll_loop( - [k_block, &q_group_vec, &group_accums](int token_group_idx) { - k_load_vec_type k_load_group_vec(k_block + token_group_idx * x * - TOKEN_PER_GROUP); - k_vec_type k_group_vec(k_load_group_vec); - vec_op::fma(group_accums[token_group_idx], q_group_vec, - k_group_vec); - vec_op::prefetch(k_block + x * BLOCK_SIZE + - token_group_idx * x * TOKEN_PER_GROUP); - }); - } - } else { - for (int q_offset = 0; q_offset < HEAD_SIZE; - q_offset += x, k_block += x * BLOCK_SIZE) { - q_load_vec_type q_load_group_vec(q + q_offset); - q_vec_type q_group_vec(q_load_group_vec); - for (int token_group_start = 0; token_group_start < group_num; - token_group_start += UNROLL_GROUP_NUM) { - vec_op::unroll_loop( - [token_group_start, k_block, &q_group_vec, - &group_accums](int token_group_idx) { - token_group_idx += token_group_start; - k_load_vec_type k_load_group_vec(k_block + token_group_idx * x * - TOKEN_PER_GROUP); - k_vec_type k_group_vec(k_load_group_vec); - vec_op::fma(group_accums[token_group_idx], q_group_vec, - k_group_vec); - vec_op::prefetch(k_block + x * BLOCK_SIZE + - token_group_idx * x * TOKEN_PER_GROUP); - }); - } - } - } - - for (int token_group_idx = 0; token_group_idx < group_num; - ++token_group_idx) { - vec_op::unroll_loop( - [&group_accums, logits, scale, token_group_idx](int token_idx) { - float dot_v = - group_accums[token_group_idx] - .template reduce_sub_sum(token_idx); - logits[token_group_idx * TOKEN_PER_GROUP + token_idx] = - dot_v * scale; - }); - } - } -}; - -template -FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block, - acc_t&& acc) { - using v_load_vec_type = typename KernelVecType::v_load_vec_type; - constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); - static_assert(BLOCK_SIZE == ELEM_NUM); - vec_op::FP32Vec16 prob_vec(prob); - - vec_op::unroll_loop([&](int head_elem_idx) { - v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx); - vec_op::FP32Vec16 fp32_v_vec(v_vec); - acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; - }); -} -}; // namespace - -// Paged attention v1 -namespace { -template -struct paged_attention_v1_impl { - static void call( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, const float scale, - const int* __restrict__ block_tables, // [num_seqs, - // max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const int num_seqs, const int num_heads) { - constexpr int x = 16 / sizeof(scalar_t); - const int num_queries_per_kv = num_heads / num_kv_heads; - - static_assert(BLOCK_SIZE == 16); - - int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE; - int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0; - TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0); - - const int parallel_work_item_num = omp_get_max_threads(); - - size_t logits_bytes = - parallel_work_item_num * max_seq_len_padded * sizeof(float); - float* logits = (float*)std::aligned_alloc( - 64, logits_bytes); // Cacheline alignment for each context token. - // [parallel_work_item_num, max_seq_len_padded] - -#pragma omp parallel for collapse(2) schedule(dynamic, 1) - for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - int seq_len = seq_lens[seq_idx]; - const int* seq_block_table = - block_tables + max_num_blocks_per_seq * seq_idx; - const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; - const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t* __restrict__ q_vec_ptr = - q + seq_idx * q_stride + head_idx * HEAD_SIZE; - const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE; - float* __restrict__ thread_block_logits = - logits + omp_get_thread_num() * max_seq_len_padded; - - // Compute logits - for (int block_idx = 0; block_idx < block_num; ++block_idx) { - const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t* __restrict__ k_block_cache_ptr = - k_cache + physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride; - float* __restrict__ head_block_logits = - thread_block_logits + block_idx * BLOCK_SIZE; - - reduceQKBlockKernel::call( - q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, - block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE); - } - - // Compute softmax - if (alibi_slopes) { - reduceSoftmaxAlibi(thread_block_logits, seq_len, - block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, - seq_len); - } else { - reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE); - } - - // Compute value - constexpr int head_elem_num_per_partition = 16; - constexpr int head_partition_num = - HEAD_SIZE / head_elem_num_per_partition; - for (int head_part_idx = 0; head_part_idx < head_partition_num; - ++head_part_idx) { - vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t* __restrict__ out_ptr = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + - head_part_idx * head_elem_num_per_partition; - for (int block_idx = 0; block_idx < block_num; ++block_idx) { - const int64_t physical_block_idx = seq_block_table[block_idx]; - const float* __restrict__ prob_vec_ptr = - thread_block_logits + block_idx * BLOCK_SIZE; - const scalar_t* __restrict__ v_block_cache_ptr = - v_cache + physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride + - BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; - reduceValueBlock( - prob_vec_ptr, v_block_cache_ptr, accums); - - if (block_idx != block_num - 1) { - const int64_t next_physical_block_idx = - seq_block_table[block_idx + 1]; - const scalar_t* __restrict__ next_v_block_cache_ptr = - v_cache + next_physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride + - BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; - vec_op::unroll_loop( - [&](int head_elem_idx) { - if (head_elem_idx % 2 == 0) { - vec_op::prefetch(next_v_block_cache_ptr + - BLOCK_SIZE * head_elem_idx); - } - }); - } - } - - vec_op::unroll_loop( - [&](int head_elem_idx) { - float value = accums[head_elem_idx].reduce_sum(); - vec_op::storeFP32(value, out_ptr + head_elem_idx); - }); - } - } - } - std::free(logits); - } -}; - -#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ - paged_attention_v1_impl::call( \ - out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ - alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ - num_heads); - -template -void paged_attention_v1_impl_launcher( - torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const std::optional& alibi_slopes) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = - alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - T* out_ptr = reinterpret_cast(out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* seq_lens_ptr = seq_lens.data_ptr(); - - switch (head_size) { - case 32: - LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE); - break; - case 64: - LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 192: - LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); - break; - case 256: - LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } -} - -#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v1_impl_launcher( \ - out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes); - -#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V1_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } -} // namespace - -void paged_attention_v1( - torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int64_t num_kv_heads, double scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(blocksparse_vert_stride <= 1, - "CPU backend does not support blocksparse attention yet."); - VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", - [&] { - CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) - CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); - CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl) - }); -} - -// Paged attention v2 -namespace { -template -struct paged_attention_v2_impl { - static void call( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - float* __restrict__ exp_sums, // [num_seqs, num_heads, - // max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, const float scale, - const int* __restrict__ block_tables, // [num_seqs, - // max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const int num_seqs, const int num_heads, const int max_num_partitions) { - constexpr int x = 16 / sizeof(scalar_t); - const int num_queries_per_kv = num_heads / num_kv_heads; - - static_assert(BLOCK_SIZE == 16); - static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0); - static_assert(PARTITION_SIZE % BLOCK_SIZE == 0); - -#pragma omp parallel for collapse(3) schedule(static, 1) - for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { - for (int partition_idx = 0; partition_idx < max_num_partitions; - ++partition_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int seq_len = seq_lens[seq_idx]; - const int start_token_idx = partition_idx * PARTITION_SIZE; - - if (start_token_idx >= seq_len) continue; - - const int partition_num = - (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - const bool no_reduce = (partition_num == 1); - const int token_num = - (std::min(seq_len, start_token_idx + PARTITION_SIZE) - - start_token_idx); - const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; - const int last_block_token_num = - token_num - (block_num - 1) * BLOCK_SIZE; - const int* seq_block_table = block_tables + - max_num_blocks_per_seq * seq_idx + - start_token_idx / BLOCK_SIZE; - const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t* __restrict__ q_vec_ptr = - q + seq_idx * q_stride + head_idx * HEAD_SIZE; - - float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; - - // Compute logits - for (int block_idx = 0; block_idx < block_num; ++block_idx) { - const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t* __restrict__ k_block_cache_ptr = - k_cache + physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride; - float* __restrict__ head_block_logits = - logits + block_idx * BLOCK_SIZE; - - reduceQKBlockKernel::call( - q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, - block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE); - } - - std::pair max_and_sum; - if (alibi_slopes) { - max_and_sum = reduceSoftmaxAlibi( - logits, token_num, block_num * BLOCK_SIZE, - alibi_slopes[head_idx], start_token_idx, seq_len); - } else { - max_and_sum = - reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); - } - - auto&& [max_logit, exp_sum] = max_and_sum; - - scalar_t* __restrict__ output_buffer = nullptr; - if (!no_reduce) { - auto idx = seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions + partition_idx; - max_logits[idx] = max_logit; - exp_sums[idx] = exp_sum; - output_buffer = - tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE + - partition_idx * HEAD_SIZE; - } else { - output_buffer = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - } - - // Compute value - constexpr int head_elem_num_per_partition = 16; - constexpr int head_partition_num = - HEAD_SIZE / head_elem_num_per_partition; - for (int head_part_idx = 0; head_part_idx < head_partition_num; - ++head_part_idx) { - vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t* __restrict__ out_ptr = - output_buffer + head_part_idx * head_elem_num_per_partition; - for (int block_idx = 0; block_idx < block_num; ++block_idx) { - const int64_t physical_block_idx = seq_block_table[block_idx]; - const float* __restrict__ prob_vec_ptr = - logits + block_idx * BLOCK_SIZE; - const scalar_t* __restrict__ v_block_cache_ptr = - v_cache + physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride + - BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; - reduceValueBlock( - prob_vec_ptr, v_block_cache_ptr, accums); - - if (block_idx != block_num - 1) { - const int64_t next_physical_block_idx = - seq_block_table[block_idx + 1]; - const scalar_t* __restrict__ next_v_block_cache_ptr = - v_cache + next_physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride + - BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; - vec_op::unroll_loop( - [&](int head_elem_idx) { - if (head_elem_idx % 2 == 0) { - vec_op::prefetch(next_v_block_cache_ptr + - BLOCK_SIZE * head_elem_idx); - } - }); - } - } - - vec_op::unroll_loop( - [&](int head_elem_idx) { - float value = accums[head_elem_idx].reduce_sum(); - vec_op::storeFP32(value, out_ptr + head_elem_idx); - }); - } - } - } - } - - // Rescale partition softmax and store the factors to exp_sums -#pragma omp parallel for collapse(2) schedule(static, 1) - for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int seq_len = seq_lens[seq_idx]; - const int partition_num = - (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - - if (partition_num == 1) continue; - - reducePartitionSoftmax( - max_logits + seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions, - exp_sums + seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions, - partition_num); - } - } - - // Reduce values - using v_load_vec_type = typename KernelVecType::v_load_vec_type; - static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); - constexpr int head_elem_num_per_group = - 16; // Note: didn't align with the cacheline size, due to some - // HEAD_SIZE didn't align with 64 bytes - static_assert(HEAD_SIZE % head_elem_num_per_group == 0); - constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; - const float* __restrict__ rescale_factors = exp_sums; -#pragma omp parallel for collapse(3) schedule(static, 1) - for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { - const int seq_len = seq_lens[seq_idx]; - const int partition_num = - (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - - if (partition_num == 1) continue; - - const float* __restrict__ seq_head_rescale_factors = - rescale_factors + seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions; - const scalar_t* __restrict__ seq_head_tmp_out = - tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE + - group_idx * head_elem_num_per_group; - scalar_t* __restrict__ seq_head_output = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + - group_idx * head_elem_num_per_group; - - vec_op::FP32Vec16 acc; - for (int i = 0; i < partition_num; ++i) { - vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]); - v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE); - vec_op::FP32Vec16 fp32_value(value); - acc = acc + fp32_value * rescale_factor; - } - v_load_vec_type cast_acc(acc); - cast_acc.save(seq_head_output); - } - } - } - } -}; - -#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ - paged_attention_v2_impl::call( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ - key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, num_seqs, num_heads, \ - max_num_partitions); - -template -void paged_attention_v2_impl_launcher( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const std::optional& alibi_slopes) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - int max_num_partitions = exp_sums.size(-1); - - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = - alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - T* out_ptr = reinterpret_cast(out.data_ptr()); - float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* seq_lens_ptr = seq_lens.data_ptr(); - - switch (head_size) { - case 32: - LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE); - break; - case 64: - LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 192: - LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); - break; - case 256: - LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } -} - -#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v2_impl_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \ - alibi_slopes); - -#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V2_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } -} // namespace - -void paged_attention_v2( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int64_t num_kv_heads, double scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(blocksparse_vert_stride <= 1, - "CPU backend does not support blocksparse attention yet."); - VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", - [&] { - CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) - CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); - CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl) - }); -} \ No newline at end of file diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp deleted file mode 100644 index 69f6d06e3c967..0000000000000 --- a/csrc/cpu/cache.cpp +++ /dev/null @@ -1,214 +0,0 @@ -#include -#include - -#include "cpu_types.hpp" - -#if defined(__x86_64__) - #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2 -#else - #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES -#endif - -namespace { -template -void copy_blocks_cpu_impl(std::vector const& key_caches, - std::vector const& value_caches, - const torch::Tensor& mapping_pairs, - const int element_num_per_block, - const int layer_num) { - const size_t pair_num = mapping_pairs.size(0); - const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; -#pragma omp parallel for collapse(2) - for (int layer = 0; layer < layer_num; ++layer) { - for (size_t pair = 0; pair < pair_num; ++pair) { - int64_t source_offset = - element_num_per_block * mapping_pairs[pair][0].item(); - int64_t target_offset = - element_num_per_block * mapping_pairs[pair][1].item(); - scalar_t* key_cache_ptr = key_caches[layer].data_ptr(); - scalar_t* source_ptr = key_cache_ptr + source_offset; - scalar_t* target_ptr = key_cache_ptr + target_offset; - std::memcpy(target_ptr, source_ptr, block_bytes); - - scalar_t* value_cache_ptr = value_caches[layer].data_ptr(); - source_ptr = value_cache_ptr + source_offset; - target_ptr = value_cache_ptr + target_offset; - std::memcpy(target_ptr, source_ptr, block_bytes); - } - } -} - -template -void reshape_and_cache_cpu_impl( - const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, - scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, - const int64_t* __restrict__ slot_mapping, const int num_tokens, - const int key_stride, const int value_stride, const int num_heads, - const int head_size, const int block_size, const int x) { - const int block_elem_num = num_heads * head_size * block_size; - -#pragma omp parallel for collapse(2) - for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int64_t slot_idx = slot_mapping[token_idx]; - if (slot_idx >= 0) { - int src_key_head_idx = token_idx * key_stride + head_idx * head_size; - int src_value_head_idx = - token_idx * value_stride + head_idx * head_size; - const scalar_t* src_key_head_ptr = key + src_key_head_idx; - const scalar_t* src_value_head_ptr = value + src_value_head_idx; - const int64_t block_index = slot_idx / block_size; - const int64_t block_offset = slot_idx % block_size; - scalar_t* target_key_head_ptr = key_cache + - block_elem_num * block_index + - head_idx * block_size * head_size; - scalar_t* target_value_head_ptr = value_cache + - block_elem_num * block_index + - head_idx * block_size * head_size; - - for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) { - const int64_t target_offset = - src_key_idx * block_size + block_offset * x; - for (int i = 0; i < x; ++i) { - target_key_head_ptr[target_offset + i] = - src_key_head_ptr[src_key_idx + i]; - } - } - - for (int src_value_idx = 0; src_value_idx < head_size; - ++src_value_idx) { - const int64_t target_offset = - src_value_idx * block_size + block_offset; - target_value_head_ptr[target_offset] = - src_value_head_ptr[src_value_idx]; - } - } - } - } -} -}; // namespace - -template -void concat_and_cache_mla_cpu_impl( - const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] - const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] - scalar_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank - // + pe_dim)] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int num_tokens, // - const int block_stride, // - const int entry_stride, // - const int kv_c_stride, // - const int k_pe_stride, // - const int kv_lora_rank, // - const int pe_dim, // - const int block_size // -) { -#pragma omp parallel for - for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { - const int64_t slot_idx = slot_mapping[token_idx]; - // NOTE: slot_idx can be -1 if the token is padded - if (slot_idx < 0) { - continue; - } - const int64_t block_idx = slot_idx / block_size; - const int64_t block_offset = slot_idx % block_size; - - auto copy = [&](const scalar_t* __restrict__ src, - scalar_t* __restrict__ dst, int src_stride, int dst_stride, - int size, int offset) { - for (int i = 0; i < size; i++) { - const int64_t src_idx = token_idx * src_stride + i; - const int64_t dst_idx = - block_idx * block_stride + block_offset * entry_stride + i + offset; - dst[dst_idx] = src[src_idx]; - } - }; - - copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); - copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); - } -} - -// Note: the key_caches and value_caches vectors are constant but -// not the Tensors they contain. The vectors need to be const refs -// in order to satisfy pytorch's C++ operator registration code. -void copy_blocks(std::vector const& key_caches, - std::vector const& value_caches, - const torch::Tensor& block_mapping) { - unsigned num_layers = key_caches.size(); - TORCH_CHECK(num_layers == value_caches.size()); - if (num_layers == 0) { - return; - } - - const int element_num_per_block = key_caches[0][0].numel(); - DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { - CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) - copy_blocks_cpu_impl(key_caches, value_caches, block_mapping, - element_num_per_block, num_layers); - CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) - }); -} - -void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, - torch::Tensor& key_cache, torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, - torch::Tensor& k_scale, torch::Tensor& v_scale) { - int num_tokens = key.size(0); - int num_heads = key.size(1); - int head_size = key.size(2); - int block_size = key_cache.size(3); - int x = key_cache.size(4); - - int key_stride = key.stride(0); - int value_stride = value.stride(0); - - DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] { - CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl) - reshape_and_cache_cpu_impl( - key.data_ptr(), value.data_ptr(), - key_cache.data_ptr(), value_cache.data_ptr(), - slot_mapping.data_ptr(), num_tokens, key_stride, value_stride, - num_heads, head_size, block_size, x); - CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl) - }); -} - -void concat_and_cache_mla( - torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] - torch::Tensor& k_pe, // [num_tokens, pe_dim] - torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + - // pe_dim)] - torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] - const std::string& kv_cache_dtype, torch::Tensor& scale) { - int num_tokens = slot_mapping.size(0); - int kv_lora_rank = kv_c.size(1); - int pe_dim = k_pe.size(1); - int block_size = kv_cache.size(1); - - TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); - TORCH_CHECK(kv_cache_dtype != "fp8"); - - int kv_c_stride = kv_c.stride(0); - int k_pe_stride = k_pe.stride(0); - int block_stride = kv_cache.stride(0); - int entry_stride = kv_cache.stride(1); - - VLLM_DISPATCH_FLOATING_TYPES( - kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] { - CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl) - concat_and_cache_mla_cpu_impl( - kv_c.data_ptr(), k_pe.data_ptr(), - kv_cache.data_ptr(), slot_mapping.data_ptr(), - num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride, - kv_lora_rank, pe_dim, block_size); - CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl) - }); -} - -void swap_blocks(torch::Tensor& src, torch::Tensor& dst, - const torch::Tensor& block_mapping) { - TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") -} diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp new file mode 100644 index 0000000000000..50f17c758c148 --- /dev/null +++ b/csrc/cpu/cpu_attn.cpp @@ -0,0 +1,249 @@ +#include "cpu_attn_vec.hpp" +#include "cpu_attn_vec16.hpp" + +#ifdef CPU_CAPABILITY_AMXBF16 + #include "cpu_attn_amx.hpp" + #define AMX_DISPATCH(...) \ + case cpu_attention::ISA::AMX: { \ + using attn_impl = cpu_attention::AttentionImpl; \ + return __VA_ARGS__(); \ + } +#else + #define AMX_DISPATCH(...) case cpu_attention::ISA::AMX: +#endif + +#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \ + case HEAD_DIM: { \ + constexpr size_t head_dim = HEAD_DIM; \ + return __VA_ARGS__(); \ + } + +#define CPU_ATTN_DISPATCH_CASE_HEADDIM(HEAD_DIM, ...) \ + [&] { \ + switch (HEAD_DIM) { \ + CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(224, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(256, __VA_ARGS__) \ + default: { \ + TORCH_CHECK(false, "Invalid CPU attention head_dim: " + \ + std::to_string(HEAD_DIM)); \ + } \ + } \ + }() + +#define CPU_ATTN_DISPATCH_IMPL(ISA_TYPE, ...) \ + [&] { \ + switch (ISA_TYPE) { \ + AMX_DISPATCH(__VA_ARGS__) \ + case cpu_attention::ISA::VEC: { \ + using attn_impl = \ + cpu_attention::AttentionImpl; \ + return __VA_ARGS__(); \ + } \ + case cpu_attention::ISA::VEC16: { \ + using attn_impl = \ + cpu_attention::AttentionImpl; \ + return __VA_ARGS__(); \ + } \ + default: { \ + TORCH_CHECK(false, "Invalid CPU attention ISA type."); \ + } \ + } \ + }() + +torch::Tensor get_scheduler_metadata( + const int64_t num_req, const int64_t num_heads_q, + const int64_t num_heads_kv, const int64_t head_dim, + const torch::Tensor& seq_lens, at::ScalarType dtype, + const torch::Tensor& query_start_loc, const bool casual, + const int64_t window_size, const std::string& isa_hint, + const bool enable_kv_split) { + cpu_attention::ISA isa; + if (isa_hint == "amx") { + isa = cpu_attention::ISA::AMX; + } else if (isa_hint == "vec") { + isa = cpu_attention::ISA::VEC; + } else if (isa_hint == "vec16") { + isa = cpu_attention::ISA::VEC16; + } else { + TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint); + } + + cpu_attention::AttentionScheduler::ScheduleInput input; + input.num_reqs = num_req; + input.num_heads_q = num_heads_q; + input.num_heads_kv = num_heads_kv; + input.head_dim = head_dim; + input.query_start_loc = query_start_loc.data_ptr(); + input.seq_lens = seq_lens.data_ptr(); + if (window_size != -1) { + input.left_sliding_window_size = window_size - 1; + if (casual) { + input.right_sliding_window_size = 0; + } else { + input.right_sliding_window_size = window_size - 1; + } + } else { + input.left_sliding_window_size = -1; + if (casual) { + input.right_sliding_window_size = 0; + } else { + input.right_sliding_window_size = -1; + } + } + input.casual = casual; + input.isa = isa; + input.enable_kv_split = enable_kv_split; + TORCH_CHECK(casual, "Only supports casual mask for now."); + + VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() { + CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] { + CPU_ATTN_DISPATCH_IMPL(isa, [&]() { + input.elem_size = sizeof(scalar_t); + input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t); + input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t); + input.output_buffer_elem_size = + sizeof(attn_impl::partial_output_buffer_t); + input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration; + input.kv_block_alignment = attn_impl::BlockSizeAlignment; + }); + }); + }); + + cpu_attention::AttentionScheduler scheduler; + torch::Tensor metadata = scheduler.schedule(input); + return metadata; +} + +void cpu_attn_reshape_and_cache( + const torch::Tensor& key, // [token_num, head_num, head_size] + const torch::Tensor& value, // [token_num, head_num, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& + value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + const torch::Tensor& slot_mapping, const std::string& isa) { + TORCH_CHECK_EQ(key.dim(), 3); + TORCH_CHECK_EQ(value.dim(), 3); + TORCH_CHECK_EQ(key_cache.dim(), 4); + TORCH_CHECK_EQ(value_cache.dim(), 4); + TORCH_CHECK_EQ(key.stride(2), 1); + TORCH_CHECK_EQ(value.stride(2), 1); + + const int64_t token_num = key.size(0); + const int64_t key_token_num_stride = key.stride(0); + const int64_t value_token_num_stride = value.stride(0); + const int64_t head_num = value.size(1); + const int64_t key_head_num_stride = key.stride(1); + const int64_t value_head_num_stride = value.stride(1); + const int64_t num_blocks = key_cache.size(0); + const int64_t num_blocks_stride = key_cache.stride(0); + const int64_t cache_head_num_stride = key_cache.stride(1); + const int64_t block_size = key_cache.size(2); + const int64_t block_size_stride = key_cache.stride(2); + const int64_t head_dim = key.size(-1); + + cpu_attention::ISA isa_tag = [&]() { + if (isa == "amx") { + return cpu_attention::ISA::AMX; + } else if (isa == "vec") { + return cpu_attention::ISA::VEC; + } else if (isa == "vec16") { + return cpu_attention::ISA::VEC16; + } else { + TORCH_CHECK(false, "Invalid ISA type: " + isa); + } + }(); + + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() { + CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] { + CPU_ATTN_DISPATCH_IMPL(isa_tag, [&]() { + attn_impl::reshape_and_cache( + key.data_ptr(), value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + slot_mapping.data_ptr(), token_num, + key_token_num_stride, value_token_num_stride, head_num, + key_head_num_stride, value_head_num_stride, num_blocks, + num_blocks_stride, cache_head_num_stride, block_size, + block_size_stride); + }); + }); + }); +} + +void cpu_attention_with_kv_cache( + const torch::Tensor& query, // [num_tokens, num_heads, head_size] + const torch::Tensor& + key_cache, // [num_blocks, num_kv_heads, block_size, head_size] + const torch::Tensor& + value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& output, // [num_tokens, num_heads, head_size] + const torch::Tensor& query_start_loc, // [num_tokens + 1] + const torch::Tensor& seq_lens, // [num_tokens] + const double scale, const bool causal, + const std::optional& alibi_slopes, // [num_heads] + const int64_t sliding_window_left, const int64_t sliding_window_right, + const torch::Tensor& block_table, // [num_tokens, max_block_num] + const double softcap, const torch::Tensor& scheduler_metadata, + const std::optional& s_aux // [num_heads] +) { + TORCH_CHECK_EQ(query.dim(), 3); + TORCH_CHECK_EQ(query.stride(2), 1); + TORCH_CHECK_EQ(key_cache.dim(), 4); + TORCH_CHECK_EQ(value_cache.dim(), 4); + + cpu_attention::AttentionInput input; + input.metadata = reinterpret_cast( + scheduler_metadata.data_ptr()); + input.num_tokens = query.size(0); + input.num_heads = query.size(1); + input.num_kv_heads = key_cache.size(1); + input.block_size = key_cache.size(2); + input.query = query.data_ptr(); + input.query_num_tokens_stride = query.stride(0); + input.query_num_heads_stride = query.stride(1); + input.cache_num_blocks_stride = key_cache.stride(0); + input.cache_num_kv_heads_stride = key_cache.stride(1); + input.blt_num_tokens_stride = block_table.stride(0); + input.key_cache = key_cache.data_ptr(); + input.value_cache = value_cache.data_ptr(); + input.output = output.data_ptr(); + input.query_start_loc = query_start_loc.data_ptr(); + input.seq_lens = seq_lens.data_ptr(); + input.block_table = block_table.data_ptr(); + input.alibi_slopes = + alibi_slopes.has_value() ? alibi_slopes->data_ptr() : nullptr; + // For now sink must be bf16 + input.s_aux = s_aux.has_value() ? s_aux->data_ptr() : nullptr; + input.scale = scale; + input.causal = causal; + input.sliding_window_left = sliding_window_left; + input.sliding_window_right = sliding_window_right; + if (input.causal) { + // to make boundary calculation easier + input.sliding_window_right = 0; + } + float softcap_fp32 = softcap; + input.softcap = softcap_fp32; + + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), "cpu_attention_with_kv_cache", [&]() { + CPU_ATTN_DISPATCH_CASE_HEADDIM(query.size(2), [&] { + CPU_ATTN_DISPATCH_IMPL(input.metadata->isa, [&]() { + TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0); + cpu_attention::AttentionMainLoop mainloop; + mainloop(&input); + }); + }); + }); +} diff --git a/csrc/cpu/cpu_attn_amx.hpp b/csrc/cpu/cpu_attn_amx.hpp new file mode 100644 index 0000000000000..8da458b99119c --- /dev/null +++ b/csrc/cpu/cpu_attn_amx.hpp @@ -0,0 +1,511 @@ +#ifndef CPU_ATTN_AMX_HPP +#define CPU_ATTN_AMX_HPP + +#include "cpu_attn_impl.hpp" + +namespace cpu_attention { +namespace { +// AMX specific +constexpr static int64_t AMX_TILE_ROW_BYTES = 64; +constexpr static int64_t AMX_TILE_ROW_NUM = 16; +constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM; + +typedef struct __tile_config { + uint8_t palette_id = 1; + uint8_t start_row = 0; + uint8_t reserved_0[14] = {0}; + uint16_t colsb[16] = {0}; + uint8_t rows[16] = {0}; +} __tilecfg; + +// 2-2-4 pattern, for 16 < m <= 32 +// TILE 0, 1: load A matrix, row num should be 16, m - 16 +// TILE 2, 3: load B matrix, row num should be 16 +// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16, m +// - 16 +template +class TileGemm224 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile, + void* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224"); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224"); + } +}; + +template <> +class TileGemm224 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, + c10::BFloat16* __restrict__ a_tile, + c10::BFloat16* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + const int32_t k_times = + dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16)); + c10::BFloat16* __restrict__ a_tile_0 = a_tile; + c10::BFloat16* __restrict__ a_tile_1 = a_tile + lda * AMX_TILE_ROW_NUM; + const int64_t a_tile_stride = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // q_buffer is prepacked + return AMX_TILE_ROW_BYTES; + } else if constexpr (phase == AttentionGemmPhase::PV) { + // logits_buffer is row-major + return lda * sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + + c10::BFloat16* __restrict__ b_tile_2 = b_tile; + c10::BFloat16* __restrict__ b_tile_3 = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // k_cache is prepacked + return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // v_cache is prepacked + return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + // k_cache, v_cache are prepacked + const int32_t b_tile_stride = AMX_TILE_ROW_BYTES; + + // logits_buffer, output_buffer are not prepacked + float* __restrict__ c_tile_4 = c_tile; + float* __restrict__ c_tile_5 = + c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float); + float* __restrict__ c_tile_6 = c_tile + AMX_TILE_ROW_NUM * ldc; + float* __restrict__ c_tile_7 = + c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float); + const int32_t c_tile_stride = ldc * sizeof(float); + + if (accum_c) { + _tile_loadd(4, c_tile_4, c_tile_stride); + _tile_loadd(5, c_tile_5, c_tile_stride); + _tile_loadd(6, c_tile_6, c_tile_stride); + _tile_loadd(7, c_tile_7, c_tile_stride); + } else { + _tile_zero(4); + _tile_zero(5); + _tile_zero(6); + _tile_zero(7); + } + + for (int32_t k = 0; k < k_times; ++k) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_tile_stride); + _tile_dpbf16ps(4, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_tile_stride); + _tile_dpbf16ps(5, 0, 3); + _tile_loadd(1, a_tile_1, a_tile_stride); + _tile_dpbf16ps(6, 1, 2); + _tile_dpbf16ps(7, 1, 3); + + // update ptrs + if constexpr (phase == AttentionGemmPhase::QK) { + // Q buffer is prepacked + a_tile_0 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + a_tile_1 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // P buffer is not prepacked + a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + } + + _tile_stored(4, c_tile_4, c_tile_stride); + _tile_stored(5, c_tile_5, c_tile_stride); + _tile_stored(6, c_tile_6, c_tile_stride); + _tile_stored(7, c_tile_7, c_tile_stride); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + const int32_t m_0 = AMX_TILE_ROW_NUM; + const int32_t m_1 = m - AMX_TILE_ROW_NUM; + config.rows[0] = m_0; + config.rows[1] = m_1; + config.rows[2] = AMX_TILE_ROW_NUM; + config.rows[3] = AMX_TILE_ROW_NUM; + config.rows[4] = m_0; + config.rows[5] = m_0; + config.rows[6] = m_1; + config.rows[7] = m_1; + _tile_loadconfig(&config); + } +}; + +// 1-2-2 pattern, for 0 < m <= 16 +// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be +// m, m +// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row +// num should be 16 +// TILE 6, 7, (6, 7): store results C matrix, row num should be +// m +template +class TileGemm122 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile, + void* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122"); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122"); + } +}; + +template <> +class TileGemm122 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, + c10::BFloat16* __restrict__ a_tile, + c10::BFloat16* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + c10::BFloat16* __restrict__ a_tile_0 = a_tile; + c10::BFloat16* __restrict__ a_tile_1 = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // q_buffer is prepacked + return a_tile + AMX_TILE_BYTES / sizeof(c10::BFloat16); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // logits_buffer is row-major + return a_tile + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + const int64_t a_tile_stride = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // q_buffer is prepacked + return AMX_TILE_ROW_BYTES; + } else if constexpr (phase == AttentionGemmPhase::PV) { + // logits_buffer is row-major + return lda * sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + + c10::BFloat16* __restrict__ b_tile_2 = b_tile; + c10::BFloat16* __restrict__ b_tile_3 = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // k_cache is prepacked + return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // v_cache is prepacked + return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + c10::BFloat16* __restrict__ b_tile_4 = + b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16); + c10::BFloat16* __restrict__ b_tile_5 = + b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16); + int64_t b_stride = AMX_TILE_ROW_BYTES; + + float* __restrict__ c_tile_6 = c_tile; + float* __restrict__ c_tile_7 = c_tile + AMX_TILE_ROW_BYTES / sizeof(float); + int64_t c_stride = ldc * sizeof(float); + + const int32_t k_times = + dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16)); + const int32_t k_group_times = k_times / 2; + const bool has_tail = (k_times % 2 == 1); + + if (accum_c) { + _tile_loadd(6, c_tile_6, c_stride); + _tile_loadd(7, c_tile_7, c_stride); + } else { + _tile_zero(6); + _tile_zero(7); + } + + for (int32_t k = 0; k < k_group_times; ++k) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_stride); + _tile_dpbf16ps(6, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_stride); + _tile_dpbf16ps(7, 0, 3); + _tile_loadd(1, a_tile_1, a_tile_stride); + _tile_stream_loadd(4, b_tile_4, b_stride); + _tile_dpbf16ps(6, 1, 4); + _tile_stream_loadd(5, b_tile_5, b_stride); + _tile_dpbf16ps(7, 1, 5); + + // update ptrs + if constexpr (phase == AttentionGemmPhase::QK) { + // Q buffer is prepacked + a_tile_0 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + a_tile_1 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // P buffer is not prepacked + a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + } + b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_5 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + } + + if (has_tail) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_stride); + _tile_dpbf16ps(6, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_stride); + _tile_dpbf16ps(7, 0, 3); + } + + _tile_stored(6, c_tile_6, c_stride); + _tile_stored(7, c_tile_7, c_stride); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + config.rows[0] = m; + config.rows[1] = m; + config.rows[2] = AMX_TILE_ROW_NUM; + config.rows[3] = AMX_TILE_ROW_NUM; + config.rows[4] = AMX_TILE_ROW_NUM; + config.rows[5] = AMX_TILE_ROW_NUM; + config.rows[6] = m; + config.rows[7] = m; + _tile_loadconfig(&config); + } +}; +} // namespace + +template +class AttentionImpl { + public: + using query_t = scalar_t; + using q_buffer_t = scalar_t; + using kv_cache_t = scalar_t; + using logits_buffer_t = float; + using partial_output_buffer_t = float; + using prob_buffer_t = scalar_t; + + constexpr static int64_t BlockSizeAlignment = + AMX_TILE_ROW_BYTES / + sizeof(kv_cache_t); // KV token num unit of QK and PV phases + constexpr static int64_t HeadDimAlignment = + 2 * (AMX_TILE_ROW_BYTES / 4); // headdim num unit of PV phase + constexpr static int64_t MaxQHeadNumPerIteration = 32; + constexpr static int64_t HeadDim = head_dim; + constexpr static ISA ISAType = ISA::AMX; + constexpr static bool scale_on_logits = true; + + public: + AttentionImpl() : current_q_head_num_(0) { + // Use all columns in AMX tiles + vec_op::unroll_loop([&](int i) { amx_tile_config_.colsb[i] = 64; }); + } + + ~AttentionImpl() { _tile_release(); } + + template