From 7721ef1786c49ec3738db5e61821183ee969d2a2 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Tue, 8 Jul 2025 13:13:44 +0800 Subject: [PATCH] [CI/Build][CPU] Fix CPU CI and remove all CPU V0 files (#20560) Signed-off-by: jiang1.li --- .../scripts/hardware_ci/run-cpu-test.sh | 24 +- .../basic_correctness/test_chunked_prefill.py | 58 -- .../models/language/generation/test_common.py | 8 +- .../models/language/pooling/test_embedding.py | 23 +- tests/models/language/pooling/test_reward.py | 5 + tests/quantization/test_compressed_tensors.py | 3 +- vllm/attention/backends/torch_sdpa.py | 546 ------------- vllm/attention/ops/ipex_attn.py | 195 ----- vllm/v1/attention/backends/cpu_attn.py | 762 +++++++++++++++++- 9 files changed, 785 insertions(+), 839 deletions(-) delete mode 100644 vllm/attention/backends/torch_sdpa.py delete mode 100644 vllm/attention/ops/ipex_attn.py diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 737b2eede9c6..afe3e4b7ef69 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -48,10 +48,16 @@ function cpu_tests() { # Run basic model test docker exec cpu-test-"$NUMA_NODE" bash -c " set -e - pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model - pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model - pytest -v -s tests/models/language/generation -m cpu_model - VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model + # Note: disable until supports V1 + # pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model + # pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model + + # Note: disable Bart until supports V1 + pytest -v -s tests/models/language/generation -m cpu_model \ + --ignore=tests/models/language/generation/test_bart.py + VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model \ + --ignore=tests/models/language/generation/test_bart.py + pytest -v -s tests/models/language/pooling -m cpu_model pytest -v -s tests/models/multimodal/generation \ --ignore=tests/models/multimodal/generation/test_mllama.py \ @@ -62,21 +68,15 @@ function cpu_tests() { docker exec cpu-test-"$NUMA_NODE" bash -c " set -e pytest -s -v \ - tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \ - tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token" + tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs[False-10-32-neuralmagic/Llama-3.2-1B-quantized.w8a8]" + # Note: disable it until supports V1 # Run AWQ test # docker exec cpu-test-"$NUMA_NODE" bash -c " # set -e # VLLM_USE_V1=0 pytest -s -v \ # tests/quantization/test_ipex_quant.py" - # Run chunked-prefill and prefix-cache test - docker exec cpu-test-"$NUMA_NODE" bash -c " - set -e - pytest -s -v -k cpu_model \ - tests/basic_correctness/test_chunked_prefill.py" - # online serving docker exec cpu-test-"$NUMA_NODE" bash -c " set -e diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 4a422e8555da..4816b76996fc 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -294,61 +294,3 @@ def test_with_prefix_caching( name_0="w/o prefix caching", name_1="with prefix caching", ) - - -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) -@pytest.mark.parametrize("dtype", ["bfloat16", "half"]) 
-@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) -@pytest.mark.parametrize("enforce_eager", [False]) -@pytest.mark.parametrize("attention_backend", ["TORCH_SDPA"]) -@pytest.mark.cpu_model -@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") -def test_models_cpu( - hf_runner: HfRunner, - vllm_runner: VllmRunner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - chunked_prefill_token_size: int, - enforce_eager: bool, - attention_backend: str, - monkeypatch: pytest.MonkeyPatch, -) -> None: - test_models( - hf_runner, - vllm_runner, - example_prompts, - model, - dtype, - max_tokens, - chunked_prefill_token_size, - enforce_eager, - 1, - attention_backend, - monkeypatch, - ) - - -@pytest.mark.parametrize("max_tokens", [16]) -@pytest.mark.parametrize("enforce_eager", [False]) -@pytest.mark.parametrize("chunk_size", [30, 32]) -@pytest.mark.parametrize("dtype", ["bfloat16", "half"]) -@pytest.mark.cpu_model -@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") -def test_with_prefix_caching_cpu( - vllm_runner: VllmRunner, - max_tokens: int, - enforce_eager: bool, - chunk_size: int, - dtype: str, -) -> None: - test_with_prefix_caching( - vllm_runner, - max_tokens, - enforce_eager, - chunk_size, - 1, - dtype, - ) diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index 7d7a62eec118..8aba68829b10 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -39,7 +39,7 @@ AITER_MODEL_LIST = [ [ pytest.param( "bigscience/bloom-560m", # bloom - testing alibi slopes - marks=[pytest.mark.core_model, pytest.mark.cpu_model], + marks=[pytest.mark.core_model], ), pytest.param( "openai-community/gpt2", # gpt2 @@ -87,7 +87,11 @@ AITER_MODEL_LIST = [ pytest.param("bigcode/starcoder2-3b"), # starcoder2 pytest.param( "TitanML/tiny-mixtral", # mixtral - marks=[pytest.mark.core_model, pytest.mark.cpu_model], + marks=[pytest.mark.core_model], + ), + pytest.param( + "Qwen/Qwen1.5-MoE-A2.7B-Chat", + marks=[pytest.mark.cpu_model], ) ]) @pytest.mark.parametrize("max_tokens", [32]) diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index 05fcf4101ff9..cc9e4102d5b7 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os from typing import Optional import pytest @@ -29,8 +28,10 @@ def v1(run_with_both_engines): # [Decoder-only] pytest.param("BAAI/bge-multilingual-gemma2", marks=[pytest.mark.core_model]), - pytest.param("intfloat/e5-mistral-7b-instruct", - marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + pytest.param( + "intfloat/e5-mistral-7b-instruct", + # CPU v1 doesn't support sliding window + marks=[pytest.mark.core_model]), # the qwen models interfere with each other (see PR # https://github.com/vllm-project/vllm/pull/18720). 
# To avoid this problem, for now we skip v0 since it will be @@ -38,11 +39,13 @@ def v1(run_with_both_engines): pytest.param("ssmits/Qwen2-7B-Instruct-embed-base", marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]), # [Encoder-only] - pytest.param("BAAI/bge-base-en-v1.5", - marks=[ - pytest.mark.core_model, pytest.mark.cpu_model, - pytest.mark.skip_v1 - ]), + pytest.param( + "BAAI/bge-base-en-v1.5", + marks=[ + # CPU only supports V1 + pytest.mark.core_model, + pytest.mark.skip_v1 + ]), pytest.param("sentence-transformers/all-MiniLM-L12-v2", marks=[pytest.mark.skip_v1]), pytest.param("intfloat/multilingual-e5-small", @@ -61,10 +64,6 @@ def test_models( model, monkeypatch, ) -> None: - if model == "intfloat/e5-mistral-7b-instruct" and current_platform.is_cpu( - ) and os.environ.get("VLLM_USE_V1", "0") == "1": - pytest.skip("CPU V1 doesn't support sliding window") - if model == "BAAI/bge-multilingual-gemma2" and current_platform.is_rocm(): # ROCm Triton FA does not currently support sliding window attention # switch to use ROCm CK FA backend diff --git a/tests/models/language/pooling/test_reward.py b/tests/models/language/pooling/test_reward.py index ec3d25ee22a9..3b7fab3ba5c9 100644 --- a/tests/models/language/pooling/test_reward.py +++ b/tests/models/language/pooling/test_reward.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os + import pytest import torch import torch.nn.functional as F @@ -84,6 +86,9 @@ def test_prm_models( dtype: str, monkeypatch, ) -> None: + if current_platform.is_cpu() and os.environ.get("VLLM_USE_V1", "0") == "0": + pytest.skip("CPU only supports V1") + if current_platform.is_rocm(): # ROCm Triton FA does not currently support sliding window attention # switch to use ROCm CK FA backend diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 3646ad6c481b..db7e50eff72b 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -45,7 +45,8 @@ def use_v0_only(monkeypatch): """ This module relies on V0 internals, so set VLLM_USE_V1=0. """ - monkeypatch.setenv('VLLM_USE_V1', '0') + if not current_platform.is_cpu(): + monkeypatch.setenv('VLLM_USE_V1', '0') @pytest.mark.parametrize( diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py deleted file mode 100644 index a490aa397991..000000000000 --- a/vllm/attention/backends/torch_sdpa.py +++ /dev/null @@ -1,546 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" Attention layer with torch scaled_dot_product_attention - and PagedAttention.""" -from dataclasses import dataclass -from typing import Any, Dict, List, Optional - -import torch -from torch.nn.functional import scaled_dot_product_attention - -# yapf conflicts with isort for this block -# yapf: disable -from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer, - AttentionMetadata, AttentionType, - is_quantized_kv_cache) -# yapf: enable -from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex -from vllm.attention.ops.paged_attn import PagedAttentionMetadata -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -@dataclass -class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): - """Metadata for TorchSDPABackend. - """ - # Currently, input sequences can only contain all prompts - # or all decoding. 
True if all sequences are prompts. - chunked_prefill: bool - seq_lens: Optional[List[int]] = None # For non-chunked prefill - - # For chunked prefill only - max_query_len: Optional[int] = None - max_kv_len: Optional[int] = None - prefill_query_start_loc: Optional[torch.Tensor] = None - kv_start_loc: Optional[torch.Tensor] = None - prefill_block_tables: Optional[torch.Tensor] = None - - # For V1 logits index only - query_start_loc: Optional[torch.Tensor] = None - - # Begin encoder attn & enc/dec cross-attn fields... - # Encoder sequence lengths representation - encoder_seq_lens: Optional[List[int]] = None - encoder_seq_lens_tensor: Optional[torch.Tensor] = None - - # Maximum sequence length among encoder sequences - max_encoder_seq_len: Optional[int] = None - - # Number of tokens input to encoder - num_encoder_tokens: Optional[int] = None - - # Cross-attention memory-mapping data structures: slot mapping - # and block tables - cross_slot_mapping: Optional[torch.Tensor] = None - cross_block_tables: Optional[torch.Tensor] = None - - def __post_init__(self): - # Set during the execution of the first attention op. - # It is a list because it is needed to set per prompt - # when alibi slopes is used. It is because of the limitation - # from xformer API. - # will not appear in the __repr__ and __init__ - self.attn_bias: Optional[List[torch.Tensor]] = None - self.encoder_attn_bias: Optional[List[torch.Tensor]] = None - self.cross_attn_bias: Optional[List[torch.Tensor]] = None - - @property - def is_all_encoder_attn_metadata_set(self): - ''' - All attention metadata required for encoder attention is set. - ''' - return ((self.encoder_seq_lens is not None) - and (self.encoder_seq_lens_tensor is not None) - and (self.max_encoder_seq_len is not None)) - - @property - def is_all_cross_attn_metadata_set(self): - ''' - All attention metadata required for enc/dec cross-attention is set. - - Superset of encoder attention required metadata. - ''' - return (self.is_all_encoder_attn_metadata_set - and (self.cross_slot_mapping is not None) - and (self.cross_block_tables is not None)) - - @property - def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: - if self.num_prefill_tokens == 0: - return None - return self - - @property - def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: - if self.num_decode_tokens == 0: - return None - return self - - def get_seq_lens( - self, - attn_type: str, - ): - ''' - Extract appropriate sequence lengths from attention metadata - according to attention type. - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - - Returns: - * Appropriate sequence lengths tensor for query - * Appropriate sequence lengths tensor for key & value - ''' - - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): - seq_lens_q = self.seq_lens - seq_lens_kv = self.seq_lens - elif attn_type == AttentionType.ENCODER: - seq_lens_q = self.encoder_seq_lens - seq_lens_kv = self.encoder_seq_lens - elif attn_type == AttentionType.ENCODER_DECODER: - seq_lens_q = self.seq_lens - seq_lens_kv = self.encoder_seq_lens - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - return seq_lens_q, seq_lens_kv - - def get_attn_bias( - self, - attn_type: str, - ) -> Optional[List[torch.Tensor]]: - ''' - Extract appropriate attention bias from attention metadata - according to attention type. 
- - Arguments: - - * attn_metadata: Attention metadata structure associated with attention - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - - Returns: - * Appropriate attention bias value given the attention type - ''' - - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): - return self.attn_bias - elif attn_type == AttentionType.ENCODER: - return self.encoder_attn_bias - elif attn_type == AttentionType.ENCODER_DECODER: - return self.cross_attn_bias - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - def set_attn_bias( - self, - attn_bias: List[torch.Tensor], - attn_type: str, - ) -> None: - ''' - Update appropriate attention bias field of attention metadata, - according to attention type. - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention - * attn_bias: The desired attention bias value - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - ''' - - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): - self.attn_bias = attn_bias - elif attn_type == AttentionType.ENCODER: - self.encoder_attn_bias = attn_bias - elif attn_type == AttentionType.ENCODER_DECODER: - self.cross_attn_bias = attn_bias - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - def get_seq_len_block_table_args( - self, - attn_type: str, - ) -> tuple: - ''' - The particular choice of sequence-length- and block-table-related - attributes which should be extracted from attn_metadata is dependent - on the type of attention operation. - - Decoder attn -> select entirely decoder self-attention-related fields - Encoder/decoder cross-attn -> select encoder sequence lengths & - cross-attn block-tables fields - Encoder attn -> select encoder sequence lengths fields & no block tables - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention - * is_prompt: True if prefill, False otherwise - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - - Returns: - - * Appropriate sequence-lengths tensor - * Appropriate max sequence-length scalar - * Appropriate block tables (or None) - ''' - - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): - # Decoder self-attention - # Choose max_seq_len based on whether we are in prompt_run - return (self.seq_lens_tensor, self.max_decode_seq_len, - self.block_tables) - elif attn_type == AttentionType.ENCODER_DECODER: - # Enc/dec cross-attention KVs match encoder sequence length; - # cross-attention utilizes special "cross" block tables - return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, - self.cross_block_tables) - elif attn_type == AttentionType.ENCODER: - # No block tables associated with encoder attention - return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, - None) - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - -class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - ) -> None: - if 
kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") - if blocksparse_params is not None: - raise ValueError( - "Torch SPDA does not support block-sparse attention.") - if logits_soft_cap is not None: - logger.warning_once("Torch SPDA does not support logits soft cap. " - "Outputs may be slightly off.") - if use_irope: - logger.warning_once( - "Using irope in Torch SPDA is not supported yet, it will fall" - " back to global attention for long context.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = sliding_window - self.kv_cache_dtype = kv_cache_dtype - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.need_mask = (self.alibi_slopes is not None - or self.sliding_window is not None) - - supported_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in supported_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {supported_head_sizes}.") - - if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex: - raise NotImplementedError( - "Torch SDPA backend FP8 KV cache requires " - "intel_extension_for_pytorch support.") - self.attn_type = attn_type - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: TorchSDPAMetadata, # type: ignore - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with torch SDPA and PagedAttention. - - Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size] - """ - if output_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for TorchSDPABackendImpl") - - # For warming-up - if attn_metadata is None: - return query - - attn_type = self.attn_type - if (attn_type == AttentionType.ENCODER - and (not attn_metadata.is_all_encoder_attn_metadata_set)): - raise AttributeError("Encoder attention requires setting " - "encoder metadata attributes.") - elif (attn_type == AttentionType.ENCODER_DECODER - and (not attn_metadata.is_all_cross_attn_metadata_set)): - raise AttributeError("Encoder/decoder cross-attention " - "requires setting cross-attention " - "metadata attributes.") - - # Reshape the query, key, and value tensors. - query = query.view(-1, self.num_heads, self.head_size) - if key is not None: - assert value is not None - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - else: - assert value is None - - if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): - # KV-cache during decoder-self- or - # encoder-decoder-cross-attention, but not - # during encoder attention. - # - # Even if there are no new key/value pairs to cache, - # we still need to break out key_cache and value_cache - # i.e. 
for later use by paged attention - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - if (key is not None) and (value is not None): - if attn_type == AttentionType.ENCODER_DECODER: - # Update cross-attention KV cache (prefill-only) - # During cross-attention decode, key & value will be None, - # preventing this IF-statement branch from running - updated_slot_mapping = attn_metadata.cross_slot_mapping - else: - # Update self-attention KV cache (prefill/decode) - updated_slot_mapping = attn_metadata.slot_mapping - - PagedAttention.write_to_paged_cache( - key, value, key_cache, value_cache, updated_slot_mapping, - self.kv_cache_dtype, layer._k_scale, layer._v_scale) - - if attn_type != AttentionType.ENCODER: - # Decoder self-attention supports chunked prefill. - # Encoder/decoder cross-attention requires no chunked - # prefill (100% prefill or 100% decode tokens, no mix) - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - else: - # Encoder attention - chunked prefill is not applicable; - # derive token-count from query shape & and treat them - # as 100% prefill tokens - assert attn_metadata.num_encoder_tokens is not None - num_prefill_tokens = attn_metadata.num_encoder_tokens - num_decode_tokens = 0 - - if attn_type == AttentionType.DECODER: - # Only enforce this shape-constraint for decoder - # self-attention - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - - output = torch.empty_like(query) - if prefill_meta := attn_metadata.prefill_metadata: - if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore - assert attn_metadata.seq_lens is not None - self._run_sdpa_forward(output, - query, - key, - value, - prefill_meta, - attn_type=attn_type) - else: - # prefix-enabled attention - assert not self.need_mask - import intel_extension_for_pytorch.llm.modules as ipex_modules - output = torch.empty_like(query) - ipex_modules.PagedAttention.flash_attn_varlen_func( - output[:prefill_meta.num_prefill_tokens, :, :], - query[:prefill_meta.num_prefill_tokens, :, :], - key_cache, - value_cache, - prefill_meta.prefill_query_start_loc, - prefill_meta.kv_start_loc, - prefill_meta.max_query_len, - prefill_meta.max_kv_len, - self.scale, - True, - prefill_meta.prefill_block_tables, - self.alibi_slopes, - ) - - if decode_meta := attn_metadata.decode_metadata: - assert attn_type != AttentionType.ENCODER_ONLY, ( - "Encoder-only models should not have decode metadata.") - # Decoding run. - ( - seq_lens_arg, - max_seq_len_arg, - block_tables_arg, - ) = decode_meta.get_seq_len_block_table_args(attn_type) - - PagedAttention.forward_decode( - output[attn_metadata.num_prefill_tokens:, :, :], - query[attn_metadata.num_prefill_tokens:, :, :], - key_cache, - value_cache, - block_tables_arg, - seq_lens_arg, - max_seq_len_arg, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - layer._k_scale, - layer._v_scale, - ) - - # Reshape the output tensor. 
- return output.view(-1, self.num_heads * self.head_size) - - def _run_sdpa_forward( - self, - output: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_metadata: TorchSDPAMetadata, - attn_type: str = AttentionType.DECODER, - ) -> None: - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, dim=1) - value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - - attn_masks = attn_metadata.get_attn_bias(attn_type) - if attn_masks is None: - if self.alibi_slopes is not None: - attn_masks = _make_alibi_bias( - self.alibi_slopes, query.dtype, - attn_metadata.seq_lens) # type: ignore - elif self.sliding_window is not None: - assert attn_metadata.seq_lens is not None - attn_masks = _make_sliding_window_bias( - attn_metadata.seq_lens, self.sliding_window, - query.dtype) # type: ignore - else: - seq_lens, _ = attn_metadata.get_seq_lens(attn_type) - attn_masks = [None] * len(seq_lens) - attn_metadata.set_attn_bias(attn_masks, attn_type) - - query = query.movedim(0, query.dim() - 2) - key = key.movedim(0, key.dim() - 2) - value = value.movedim(0, value.dim() - 2) - - causal_attn = (attn_type == AttentionType.DECODER) - - seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) - start_q, start_kv = 0, 0 - for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, - attn_masks): - end_q = start_q + seq_len_q - end_kv = start_kv + seq_len_kv - sub_out = scaled_dot_product_attention( - query[None, :, start_q:end_q, :], - key[None, :, start_kv:end_kv, :], - value[None, :, start_kv:end_kv, :], - attn_mask=mask, - dropout_p=0.0, - is_causal=causal_attn and mask is None, - scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) - output[start_q:end_q, :, :] = sub_out - start_q, start_kv = end_q, end_kv - - -def _make_alibi_bias( - alibi_slopes: torch.Tensor, - dtype: torch.dtype, - seq_lens: List[int], -) -> List[torch.Tensor]: - attn_biases: List[torch.Tensor] = [] - for seq_len in seq_lens: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(seq_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. 
- bias = bias[None, :] - bias[:, None] - - num_heads = alibi_slopes.shape[0] - bias = bias[None, :].repeat((num_heads, 1, 1)) - bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0) - inf_mask = torch.empty( - (1, seq_len, seq_len), - dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) - attn_biases.append((bias + inf_mask).to(dtype)) - - return attn_biases - - -def _make_sliding_window_bias( - seq_lens: List[int], - window_size: Optional[int], - dtype: torch.dtype, -) -> List[torch.Tensor]: - attn_biases: List[torch.Tensor] = [] - for seq_len in seq_lens: - tensor = torch.full( - (1, seq_len, seq_len), - dtype=dtype, - fill_value=1, - ) - shift = 0 - mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore - if window_size is not None: - mask = torch.triu(mask, diagonal=shift - window_size + 1) - mask = torch.log(mask) - attn_biases.append(mask.to(dtype)) - - return attn_biases diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py deleted file mode 100644 index 891975498916..000000000000 --- a/vllm/attention/ops/ipex_attn.py +++ /dev/null @@ -1,195 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List, Optional, Tuple - -try: - import intel_extension_for_pytorch.llm.modules as ipex_modules - _use_ipex = True -# AttributeError is to handle a bug in ipex https://github.com/intel/intel-extension-for-pytorch/pull/813 -except (ImportError, AttributeError): - _use_ipex = False - -import torch - -from vllm import _custom_ops as ops - - -class _PagedAttention: - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [32, 64, 80, 96, 112, 128, 192, 256] - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - *args, - ) -> Tuple[int, ...]: - return 2, num_blocks, block_size * num_kv_heads * head_size - - @staticmethod - def split_kv_cache( - kv_cache: torch.Tensor, - num_kv_heads: int, - head_size: int, - *args, - ) -> Tuple[torch.Tensor, torch.Tensor]: - x = 16 // kv_cache.element_size() - num_blocks = kv_cache.shape[1] - - key_cache = kv_cache[0] - key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, - -1, x) - value_cache = kv_cache[1] - value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) - return key_cache, value_cache - - @staticmethod - def write_to_paged_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - *args, - ) -> None: - ops.reshape_and_cache( - key, - value, - key_cache, - value_cache, - slot_mapping.flatten(), - kv_cache_dtype, - k_scale, - v_scale, - ) - - @staticmethod - def forward_decode( - output: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - context_lens: torch.Tensor, - max_context_len: int, - kv_cache_dtype: str, - num_kv_heads: int, - scale: float, - alibi_slopes: Optional[torch.Tensor], - k_scale: torch.Tensor, - v_scale: torch.Tensor, - *args, - ) -> None: - tp_rank: int = 0 - blocksparse_local_blocks: int = 0 - blocksparse_vert_stride: int = 0 - blocksparse_block_size: int = 64 - blocksparse_head_sliding_step: int = 0 - block_size = value_cache.shape[3] - - ops.paged_attention_v1( - output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - context_lens, - block_size, 
- max_context_len, - alibi_slopes, - kv_cache_dtype, - k_scale, - v_scale, - tp_rank, - blocksparse_local_blocks, - blocksparse_vert_stride, - blocksparse_block_size, - blocksparse_head_sliding_step, - ) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - *args, - ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - ops.copy_blocks(key_caches, value_caches, src_to_dists) - - -class _IPEXPagedAttention(_PagedAttention): - - @staticmethod - def split_kv_cache( - kv_cache: torch.Tensor, - num_kv_heads: int, - head_size: int, - *args, - ) -> Tuple[torch.Tensor, torch.Tensor]: - num_blocks = kv_cache.shape[1] - - key_cache = kv_cache[0] - key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size) - value_cache = kv_cache[1] - value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size) - return key_cache, value_cache - - @staticmethod - def write_to_paged_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - *args, - ) -> None: - ipex_modules.PagedAttention.reshape_and_cache( - key, value, key_cache, value_cache, - slot_mapping.flatten().int()) - - @staticmethod - def forward_decode( - output: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - context_lens: torch.Tensor, - max_context_len: int, - kv_cache_dtype: str, - num_kv_heads: int, - scale: float, - alibi_slopes: Optional[torch.Tensor], - k_scale: torch.Tensor, - v_scale: torch.Tensor, - *args, - ) -> None: - block_size = value_cache.shape[2] - head_mapping = torch.arange( - 0, - num_kv_heads, - device="cpu", - dtype=torch.int32, - ).view(num_kv_heads, - 1).repeat_interleave(query.size(1) // num_kv_heads).flatten() - ipex_modules.PagedAttention.single_query_cached_kv_attention( - output, query.contiguous(), key_cache, value_cache, head_mapping, - scale, block_tables, context_lens, block_size, max_context_len, - alibi_slopes) - - -PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 37c04c7a029e..08e802958b69 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -1,14 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Any, Optional + import numpy as np import torch +from torch.nn.functional import scaled_dot_product_attention -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) -from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl, - TorchSDPAMetadata) +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType, + is_quantized_kv_cache) from vllm.attention.backends.utils import CommonAttentionState -from vllm.attention.ops.ipex_attn import PagedAttention +from vllm.logger import init_logger from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.core.sched.output import SchedulerOutput @@ -17,18 +21,28 @@ from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.cpu_model_runner import CPUModelRunner from 
vllm.v1.worker.gpu_input_batch import InputBatch +try: + import intel_extension_for_pytorch.llm.modules as ipex_modules + _use_ipex = True +# AttributeError is to handle a bug in ipex +# https://github.com/intel/intel-extension-for-pytorch/pull/813 +except (ImportError, AttributeError): + _use_ipex = False + +from vllm import _custom_ops as ops + +logger = init_logger(__name__) + class TorchSDPABackend(AttentionBackend): accept_output_buffer: bool = False - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return PagedAttention.get_supported_head_sizes() - @classmethod def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: + attn_impl = _get_paged_attn_impl() + is_valid, supported_head_sizes = attn_impl.validate_head_size( + head_size) + if not is_valid: attn_type = cls.__name__.removesuffix("Backend") raise ValueError( f"Head size {head_size} is not supported by {attn_type}. " @@ -63,14 +77,239 @@ class TorchSDPABackend(AttentionBackend): num_kv_heads: int, head_size: int, ) -> tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) + return _get_paged_attn_impl().get_kv_cache_shape( + num_blocks, block_size, num_kv_heads, head_size) @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: return False +@dataclass +class TorchSDPAMetadata(AttentionMetadata): + """Metadata for PagedAttention.""" + # (batch_size,). The length of sequences (entire tokens seen so far) per + # sequence. + seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length in the batch. 0 if it is prefill-only batch. + max_decode_seq_len: int + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + """Metadata for TorchSDPABackend. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + chunked_prefill: bool + seq_lens: Optional[list[int]] = None # For non-chunked prefill + + # For chunked prefill only + max_query_len: Optional[int] = None + max_kv_len: Optional[int] = None + prefill_query_start_loc: Optional[torch.Tensor] = None + kv_start_loc: Optional[torch.Tensor] = None + prefill_block_tables: Optional[torch.Tensor] = None + + # For V1 logits index only + query_start_loc: Optional[torch.Tensor] = None + + # Begin encoder attn & enc/dec cross-attn fields... + # Encoder sequence lengths representation + encoder_seq_lens: Optional[list[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. 
+ # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[list[torch.Tensor]] = None + self.encoder_attn_bias: Optional[list[torch.Tensor]] = None + self.cross_attn_bias: Optional[list[torch.Tensor]] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((self.encoder_seq_lens is not None) + and (self.encoder_seq_lens_tensor is not None) + and (self.max_encoder_seq_len is not None)) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return (self.is_all_encoder_attn_metadata_set + and (self.cross_slot_mapping is not None) + and (self.cross_block_tables is not None)) + + @property + def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: + if self.num_prefill_tokens == 0: + return None + return self + + @property + def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: + if self.num_decode_tokens == 0: + return None + return self + + def get_seq_lens( + self, + attn_type: str, + ): + ''' + Extract appropriate sequence lengths from attention metadata + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate sequence lengths tensor for query + * Appropriate sequence lengths tensor for key & value + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + seq_lens_q = self.seq_lens + seq_lens_kv = self.seq_lens + elif attn_type == AttentionType.ENCODER: + seq_lens_q = self.encoder_seq_lens + seq_lens_kv = self.encoder_seq_lens + elif attn_type == AttentionType.ENCODER_DECODER: + seq_lens_q = self.seq_lens + seq_lens_kv = self.encoder_seq_lens + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + return seq_lens_q, seq_lens_kv + + def get_attn_bias( + self, + attn_type: str, + ) -> Optional[list[torch.Tensor]]: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate attention bias value given the attention type + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + return self.attn_bias + elif attn_type == AttentionType.ENCODER: + return self.encoder_attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + return self.cross_attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + def set_attn_bias( + self, + attn_bias: list[torch.Tensor], + attn_type: str, + ) -> None: + ''' + Update appropriate attention bias field of attention metadata, + according to attention type. 
+ + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_bias: The desired attention bias value + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + self.attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER: + self.encoder_attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + self.cross_attn_bias = attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + def get_seq_len_block_table_args( + self, + attn_type: str, + ) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. + + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + return (self.seq_lens_tensor, self.max_decode_seq_len, + self.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, + self.cross_block_tables) + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, + None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec, @@ -182,3 +421,500 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): ) return attn_metadata + + +class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if blocksparse_params is not None: + raise ValueError( + "Torch SPDA does not support block-sparse attention.") + if logits_soft_cap is not None: + logger.warning_once("Torch SPDA does not support logits soft cap. 
" + "Outputs may be slightly off.") + if use_irope: + logger.warning_once( + "Using irope in Torch SPDA is not supported yet, it will fall" + " back to global attention for long context.") + self.paged_attn_impl = _get_paged_attn_impl() + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex: + raise NotImplementedError( + "Torch SDPA backend FP8 KV cache requires " + "intel_extension_for_pytorch support.") + self.attn_type = attn_type + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: TorchSDPAMetadata, # type: ignore + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with torch SDPA and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for TorchSDPABackendImpl") + + # For warming-up + if attn_metadata is None: + return query + + attn_type = self.attn_type + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): + # KV-cache during decoder-self- or + # encoder-decoder-cross-attention, but not + # during encoder attention. + # + # Even if there are no new key/value pairs to cache, + # we still need to break out key_cache and value_cache + # i.e. 
for later use by paged attention + key_cache, value_cache = self.paged_attn_impl.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + if (key is not None) and (value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + self.paged_attn_impl.write_to_paged_cache( + key, value, key_cache, value_cache, updated_slot_mapping, + self.kv_cache_dtype, layer._k_scale, layer._v_scale) + + if attn_type != AttentionType.ENCODER: + # Decoder self-attention supports chunked prefill. + # Encoder/decoder cross-attention requires no chunked + # prefill (100% prefill or 100% decode tokens, no mix) + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + else: + # Encoder attention - chunked prefill is not applicable; + # derive token-count from query shape & and treat them + # as 100% prefill tokens + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + num_decode_tokens = 0 + + if attn_type == AttentionType.DECODER: + # Only enforce this shape-constraint for decoder + # self-attention + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + if prefill_meta := attn_metadata.prefill_metadata: + if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore + assert attn_metadata.seq_lens is not None + self._run_sdpa_forward(output, + query, + key, + value, + prefill_meta, + attn_type=attn_type) + else: + # prefix-enabled attention + assert not self.need_mask + import intel_extension_for_pytorch.llm.modules as ipex_modules + output = torch.empty_like(query) + ipex_modules.PagedAttention.flash_attn_varlen_func( + output[:prefill_meta.num_prefill_tokens, :, :], + query[:prefill_meta.num_prefill_tokens, :, :], + key_cache, + value_cache, + prefill_meta.prefill_query_start_loc, + prefill_meta.kv_start_loc, + prefill_meta.max_query_len, + prefill_meta.max_kv_len, + self.scale, + True, + prefill_meta.prefill_block_tables, + self.alibi_slopes, + ) + + if decode_meta := attn_metadata.decode_metadata: + assert attn_type != AttentionType.ENCODER_ONLY, ( + "Encoder-only models should not have decode metadata.") + # Decoding run. + ( + seq_lens_arg, + max_seq_len_arg, + block_tables_arg, + ) = decode_meta.get_seq_len_block_table_args(attn_type) + + self.paged_attn_impl.forward_decode( + output[attn_metadata.num_prefill_tokens:, :, :], + query[attn_metadata.num_prefill_tokens:, :, :], + key_cache, + value_cache, + block_tables_arg, + seq_lens_arg, + max_seq_len_arg, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + ) + + # Reshape the output tensor. 
+ return output.view(-1, self.num_heads * self.head_size) + + def _run_sdpa_forward( + self, + output: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: TorchSDPAMetadata, + attn_type: str = AttentionType.DECODER, + ) -> None: + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + attn_masks = attn_metadata.get_attn_bias(attn_type) + if attn_masks is None: + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, query.dtype, + attn_metadata.seq_lens) # type: ignore + elif self.sliding_window is not None: + assert attn_metadata.seq_lens is not None + attn_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, self.sliding_window, + query.dtype) # type: ignore + else: + seq_lens, _ = attn_metadata.get_seq_lens(attn_type) + attn_masks = [None] * len(seq_lens) + attn_metadata.set_attn_bias(attn_masks, attn_type) + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + causal_attn = (attn_type == AttentionType.DECODER) + + seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) + start_q, start_kv = 0, 0 + for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, + attn_masks): + end_q = start_q + seq_len_q + end_kv = start_kv + seq_len_kv + sub_out = scaled_dot_product_attention( + query[None, :, start_q:end_q, :], + key[None, :, start_kv:end_kv, :], + value[None, :, start_kv:end_kv, :], + attn_mask=mask, + dropout_p=0.0, + is_causal=causal_attn and mask is None, + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + output[start_q:end_q, :, :] = sub_out + start_q, start_kv = end_q, end_kv + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: list[int], +) -> list[torch.Tensor]: + attn_biases: list[torch.Tensor] = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. 
+ bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)) + bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0) + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_sliding_window_bias( + seq_lens: list[int], + window_size: Optional[int], + dtype: torch.dtype, +) -> list[torch.Tensor]: + attn_biases: list[torch.Tensor] = [] + for seq_len in seq_lens: + tensor = torch.full( + (1, seq_len, seq_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + attn_biases.append(mask.to(dtype)) + + return attn_biases + + +class _PagedAttention: + + @staticmethod + def validate_head_size(head_size: int) -> tuple[bool, list[int]]: + SUPPORT_HS = [32, 64, 80, 96, 112, 128, 192, 256] + return head_size in SUPPORT_HS, SUPPORT_HS + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + *args, + ) -> tuple[int, ...]: + return 2, num_blocks, block_size * num_kv_heads * head_size + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> tuple[torch.Tensor, torch.Tensor]: + x = 16 // kv_cache.element_size() + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + @staticmethod + def forward_decode( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + tp_rank: int = 0 + blocksparse_local_blocks: int = 0 + blocksparse_vert_stride: int = 0 + blocksparse_block_size: int = 64 + blocksparse_head_sliding_step: int = 0 + block_size = value_cache.shape[3] + + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + @staticmethod + def copy_blocks( + kv_caches: list[torch.Tensor], + src_to_dists: torch.Tensor, + *args, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +class _IPEXPagedAttention(_PagedAttention): + + @staticmethod + def validate_head_size(head_size: 
int) -> tuple[bool, list[int]]: + return True, [] + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> tuple[torch.Tensor, torch.Tensor]: + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + ipex_modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, + slot_mapping.flatten().int()) + + @staticmethod + def forward_decode( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + block_size = value_cache.shape[2] + head_mapping = torch.arange( + 0, + num_kv_heads, + device="cpu", + dtype=torch.int32, + ).view(num_kv_heads, + 1).repeat_interleave(query.size(1) // num_kv_heads).flatten() + ipex_modules.PagedAttention.single_query_cached_kv_attention( + output, query.contiguous(), key_cache, value_cache, head_mapping, + scale, block_tables, context_lens, block_size, max_context_len, + alibi_slopes) + + +def _get_paged_attn_impl(): + if _use_ipex: + return _IPEXPagedAttention + else: + return _PagedAttention
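
Note on the backend-selection pattern introduced above: the new vllm/v1/attention/backends/cpu_attn.py folds the former V0 torch_sdpa.py and ipex_attn.py into a single V1 file and picks the paged-attention helper at import time via _get_paged_attn_impl(), preferring the IPEX kernels when intel_extension_for_pytorch imports cleanly and falling back to the generic ops otherwise. The following is a minimal standalone sketch of that try-import dispatch and of the new validate_head_size() contract (returning a (is_valid, supported_sizes) tuple); class and function names here are simplified illustrations, not the vLLM implementation itself.

    # Sketch only: mirrors the dispatch logic in the patch, with illustrative names.
    try:
        import intel_extension_for_pytorch.llm.modules as ipex_modules  # noqa: F401
        _use_ipex = True
    # AttributeError guards against a known ipex import bug, as in the patch.
    except (ImportError, AttributeError):
        _use_ipex = False


    class _GenericPagedAttention:

        @staticmethod
        def validate_head_size(head_size: int) -> tuple[bool, list[int]]:
            # The generic CPU kernels only accept a fixed set of head sizes.
            supported = [32, 64, 80, 96, 112, 128, 192, 256]
            return head_size in supported, supported


    class _IPEXPagedAttention(_GenericPagedAttention):

        @staticmethod
        def validate_head_size(head_size: int) -> tuple[bool, list[int]]:
            # IPEX kernels impose no head-size restriction, so any size is valid.
            return True, []


    def get_paged_attn_impl():
        # Resolve the implementation once, based on import availability.
        return _IPEXPagedAttention if _use_ipex else _GenericPagedAttention


    if __name__ == "__main__":
        impl = get_paged_attn_impl()
        is_valid, supported = impl.validate_head_size(96)
        if not is_valid:
            raise ValueError(
                f"Head size 96 is not supported. Supported sizes: {supported}")
        print(f"Using {impl.__name__}; head size 96 accepted")

The tuple return lets the backend raise a single, uniform error message ("Head size ... is not supported by ...") regardless of which implementation was selected, which is why TorchSDPABackend.validate_head_size() in the patch no longer hard-codes the supported-size list.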