diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py
index eaa0fa1d5db39..8461ed61480b7 100644
--- a/vllm/attention/backends/registry.py
+++ b/vllm/attention/backends/registry.py
@@ -42,6 +42,9 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
     """

     FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
+    FLASH_ATTN_DIFFKV = (
+        "vllm.v1.attention.backends.flash_attn_diffkv.FlashAttentionDiffKVBackend"
+    )
     TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
     ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
     ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py
index 43bfa4f8324cc..662ecef3ac8f6 100644
--- a/vllm/model_executor/models/openpangu.py
+++ b/vllm/model_executor/models/openpangu.py
@@ -79,9 +79,10 @@ from vllm.model_executor.models.utils import (
     sequence_parallel_chunk,
 )
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import set_default_rope_theta
-from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
+from vllm.v1.attention.backends.flash_attn_diffkv import FlashAttentionDiffKVBackend


 def check_ffn_act_fn(act_fn: str):
@@ -645,6 +646,7 @@ class OpenPanguSinkAttention(nn.Module):
         else:
             sliding_window = None

+        FlashAttentionDiffKVBackend.set_head_size_v(self.v_channels)
         self.attn = StaticSinkAttention(
             self.num_heads,
             self.head_dim,
@@ -656,7 +658,7 @@ class OpenPanguSinkAttention(nn.Module):
             per_layer_sliding_window=sliding_window,
             attn_type=attn_type,
             prefix=f"{prefix}.attn",
-            attn_backend=FlashAttentionBackend,
+            attn_backend=FlashAttentionDiffKVBackend,
             head_size_v=self.v_channels,
         )

@@ -668,7 +670,7 @@ class OpenPanguSinkAttention(nn.Module):
                     self.num_kv_heads,
                     self.head_dim,
                 ),
-                device=torch.cuda.current_device(),
+                device=current_platform.current_device(),
                 dtype=config.torch_dtype,
             )
         )
@@ -688,7 +690,7 @@ class OpenPanguSinkAttention(nn.Module):
                     self.num_kv_heads,
                     self.v_channels,
                 ),
-                device=torch.cuda.current_device(),
+                device=current_platform.current_device(),
                 dtype=config.torch_dtype,
             )
         )
@@ -706,9 +708,11 @@ class OpenPanguSinkAttention(nn.Module):
                 self.num_kv_heads,
                 self.v_channels,
             ),
-            device=torch.cuda.current_device(),
+            device=current_platform.current_device(),
             dtype=config.torch_dtype,
         )
+        # To enable dummy runs without loaded weights
+        self.post_weight_load()

     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
         output_dim = getattr(param, "output_dim", None)
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 8b030a04b438d..f5ad98cf2125c 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -18,9 +18,6 @@ from vllm.attention.backends.abstract import (
 from vllm.attention.layer import Attention
 from vllm.attention.ops.common import cp_lse_ag_out_rs
 from vllm.attention.ops.merge_attn_states import merge_attn_states
-from vllm.attention.ops.triton_reshape_and_cache_flash import (
-    triton_reshape_and_cache_flash_diffkv,
-)
 from vllm.attention.utils.fa_utils import (
     flash_attn_supports_fp8,
     get_flash_attn_version,
@@ -108,48 +105,28 @@ class FlashAttentionBackend(AttentionBackend):
         num_kv_heads: int,
         head_size: int,
         cache_dtype_str: str = "auto",
-        head_size_v: int | None = None,
     ) -> tuple[int, ...]:
         if block_size % 16 != 0:
             raise ValueError("Block size must be a multiple of 16.")
-        if head_size_v is None or head_size == head_size_v:
-            return (2, num_blocks, block_size, num_kv_heads, head_size)
-        else:
-            return (
-                num_blocks,
-                block_size,
-                num_kv_heads,
-                head_size + head_size_v,
-            )
+        return (2, num_blocks, block_size, num_kv_heads, head_size)

     @staticmethod
     def get_kv_cache_stride_order(
         include_num_layers_dimension: bool = False,
-        diff_kv: bool = False,
     ) -> tuple[int, ...]:
         # `stride_order` indicates the permutation that gets
         # us from `get_kv_cache_shape` to the actual memory layout we want.
         cache_layout = get_kv_cache_layout()
         if cache_layout == "NHD" and include_num_layers_dimension:
-            if not diff_kv:
-                # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
-                return (2, 0, 1, 3, 4, 5)
-            else:
-                # (num_blocks, num_layers, block_size,
-                # num_kv_heads, head_size + head_size_v)
-                return (0, 1, 2, 3, 4)
+            # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
+            return (2, 0, 1, 3, 4, 5)
         elif cache_layout == "NHD":
-            stride_order = (0, 1, 2, 3, 4) if not diff_kv else (0, 1, 2, 3)
+            stride_order = (0, 1, 2, 3, 4)
         elif cache_layout == "HND" and include_num_layers_dimension:
-            if not diff_kv:
-                # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
-                return (2, 4, 0, 1, 3, 5)
-            else:
-                # (num_blocks, num_kv_heads, num_layers,
-                # block_size, head_size + head_size_v)
-                return (2, 3, 0, 1, 4)
+            # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
+            return (2, 4, 0, 1, 3, 5)
         elif cache_layout == "HND":
-            stride_order = (0, 1, 3, 2, 4) if not diff_kv else (0, 2, 1, 3)
+            stride_order = (0, 1, 3, 2, 4)
         else:
             raise ValueError(f"Unknown cache layout format {cache_layout}.")
         return stride_order
@@ -599,14 +576,11 @@ class FlashAttentionImpl(AttentionImpl):
             query: shape = [num_tokens, num_heads, head_size]
             key: shape = [num_tokens, num_kv_heads, head_size]
             value: shape = [num_tokens, num_kv_heads, head_size]
-                or [num_tokens, num_kv_heads, head_size_v]
             kv_cache: shape =
                 [2, num_blocks, block_size, num_kv_heads, head_size]
-                or [num_blocks, block_size, num_kv_heads, head_size + head_size_v]
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
-                or [num_tokens, num_heads * head_size_v]
         NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads). We use
              torch's .expand() to avoid duplicating values
@@ -649,13 +623,7 @@ class FlashAttentionImpl(AttentionImpl):
             )

         # For decoder and cross-attention, use KV cache as before
-        if self.head_size == kv_cache.shape[-1]:
-            # Same head_size for K and V
-            key_cache, value_cache = kv_cache.unbind(0)
-        else:
-            # Different head_size for K and V
-            key_cache = kv_cache[..., : self.head_size]
-            value_cache = kv_cache[..., self.head_size :]
+        key_cache, value_cache = kv_cache.unbind(0)

         # key and value may be None in the case of cross attention. They are
         # calculated once based on the output from the encoder and then cached
@@ -672,29 +640,16 @@ class FlashAttentionImpl(AttentionImpl):
             # and value[:num_actual_tokens] because the reshape_and_cache_flash
             # op uses the slot_mapping's shape to determine the number of
             # actual tokens.
-            if self.head_size == kv_cache.shape[-1]:
-                # kv_cache update for same head_size K and V
-                reshape_and_cache_flash(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
-            else:
-                # kv_cache update for different head_size K and V
-                triton_reshape_and_cache_flash_diffkv(
-                    key,
-                    value,
-                    kv_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
+            reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )

         if self.kv_cache_dtype.startswith("fp8"):
             # queries are quantized in the attention layer
diff --git a/vllm/v1/attention/backends/flash_attn_diffkv.py b/vllm/v1/attention/backends/flash_attn_diffkv.py
new file mode 100644
index 0000000000000..2e36740bd9e52
--- /dev/null
+++ b/vllm/v1/attention/backends/flash_attn_diffkv.py
@@ -0,0 +1,269 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Attention layer with FlashAttention for different K and V head sizes."""
+
+import torch
+
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.ops.triton_reshape_and_cache_flash import (
+    triton_reshape_and_cache_flash_diffkv,
+)
+from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
+
+if is_flash_attn_varlen_func_available():
+    from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+from vllm.logger import init_logger
+from vllm.v1.attention.backends.utils import get_kv_cache_layout
+
+from .flash_attn import (
+    FlashAttentionBackend,
+    FlashAttentionImpl,
+    FlashAttentionMetadata,
+    cascade_attention,
+)
+
+logger = init_logger(__name__)
+
+
+class FlashAttentionDiffKVBackend(FlashAttentionBackend):
+    # Default V head size; models override it via set_head_size_v().
+    head_size_v: int = 128
+
+    @classmethod
+    def set_head_size_v(cls, head_size_v: int) -> None:
+        cls.head_size_v = head_size_v
+
+    @staticmethod
+    def get_name() -> str:
+        return "FLASH_ATTN_DIFFKV"
+
+    @staticmethod
+    def get_impl_cls() -> type["FlashAttentionImpl"]:
+        return FlashAttentionDiffKVImpl
+
+    # Keep the interface of get_kv_cache_shape unchanged,
+    # but account for head_size_v in the returned shape.
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+        cache_dtype_str: str = "auto",
+    ) -> tuple[int, ...]:
+        if block_size % 16 != 0:
+            raise ValueError("Block size must be a multiple of 16.")
+        return (
+            num_blocks,
+            block_size,
+            num_kv_heads,
+            head_size + FlashAttentionDiffKVBackend.head_size_v,
+        )
+
+    @staticmethod
+    def get_kv_cache_stride_order(
+        include_num_layers_dimension: bool = False,
+    ) -> tuple[int, ...]:
+        # `stride_order` indicates the permutation that gets
+        # us from `get_kv_cache_shape` to the actual memory layout we want.
+        cache_layout = get_kv_cache_layout()
+        if cache_layout == "NHD" and include_num_layers_dimension:
+            # (num_blocks, num_layers, block_size,
+            # num_kv_heads, head_size + head_size_v)
+            return (1, 0, 2, 3, 4)
+        elif cache_layout == "NHD":
+            stride_order = (0, 1, 2, 3)
+        elif cache_layout == "HND" and include_num_layers_dimension:
+            # (num_blocks, num_kv_heads, num_layers,
+            # block_size, head_size + head_size_v)
+            return (1, 3, 0, 2, 4)
+        elif cache_layout == "HND":
+            stride_order = (0, 2, 1, 3)
+        else:
+            raise ValueError(f"Unknown cache layout format {cache_layout}.")
+        return stride_order
+
+
+class FlashAttentionDiffKVImpl(FlashAttentionImpl):
+    def forward(
+        self,
+        layer: torch.nn.Module,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        output: torch.Tensor | None = None,
+        output_scale: torch.Tensor | None = None,
+        output_block_scale: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention.
+
+        Args:
+            query: shape = [num_tokens, num_heads, head_size]
+            key: shape = [num_tokens, num_kv_heads, head_size]
+            value: shape = [num_tokens, num_kv_heads, head_size_v]
+            kv_cache: shape =
+                [num_blocks, block_size, num_kv_heads, head_size + head_size_v]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [num_tokens, num_heads * head_size_v]
+        NOTE: FP8 quantization, flash-attn expect the size of
+              {q,k,v}_descale to be (num_sequences, num_kv_heads).
+              We use torch's .expand() to avoid duplicating values
+        """
+        assert output is not None, "Output tensor must be provided."
+
+        if output_scale is not None or output_block_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported for "
+                "FlashAttentionDiffKVImpl"
+            )
+
+        if attn_metadata is None:
+            # Profiling run.
+            return output.fill_(0)
+
+        attn_type = self.attn_type
+
+        # IMPORTANT!
+        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
+        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
+        # in this method. For example, `view` and `slice` (or `[:n]`) operations
+        # are surprisingly slow even in the case they do not invoke any GPU ops.
+        # Minimize the PyTorch ops in this method as much as possible.
+        # Whenever making a change in this method, please benchmark the
+        # performance to make sure it does not introduce any overhead.
+
+        num_actual_tokens = attn_metadata.num_actual_tokens
+
+        # Handle encoder attention differently - no KV cache needed
+        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
+            # For encoder attention,
+            # we use direct Q, K, V tensors without caching
+            return self._forward_encoder_attention(
+                query[:num_actual_tokens],
+                key[:num_actual_tokens],
+                value[:num_actual_tokens],
+                output[:num_actual_tokens],
+                attn_metadata,
+                layer,
+            )
+
+        # For decoder and cross-attention, use KV cache as before
+        # Different head_size for K and V
+        key_cache = kv_cache[..., : self.head_size]
+        value_cache = kv_cache[..., self.head_size :]
+
+        # key and value may be None in the case of cross attention. They are
+        # calculated once based on the output from the encoder and then cached
+        # in KV cache.
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+            # not padded. However, we don't need to do key[:num_actual_tokens]
+            # and value[:num_actual_tokens] because the reshape_and_cache_flash
+            # op uses the slot_mapping's shape to determine the number of
+            # actual tokens.
+
+            # kv_cache update for different head_size K and V
+            triton_reshape_and_cache_flash_diffkv(
+                key,
+                value,
+                kv_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
+
+        if self.kv_cache_dtype.startswith("fp8"):
+            # queries are quantized in the attention layer
+            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
+                self.kv_cache_dtype
+            )
+            key_cache = key_cache.view(dtype)
+            value_cache = value_cache.view(dtype)
+
+        if not attn_metadata.use_cascade:
+            cu_seqlens_q = attn_metadata.query_start_loc
+            seqused_k = attn_metadata.seq_lens
+            max_seqlen_q = attn_metadata.max_query_len
+            max_seqlen_k = attn_metadata.max_seq_len
+            block_table = attn_metadata.block_table
+            scheduler_metadata = attn_metadata.scheduler_metadata
+
+            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
+
+            if self.dcp_world_size > 1:
+                self._forward_with_dcp(
+                    query[:num_actual_tokens],
+                    key[:num_actual_tokens],
+                    value[:num_actual_tokens],
+                    key_cache,
+                    value_cache,
+                    output[:num_actual_tokens],
+                    attn_metadata,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
+                )
+                return output
+            else:
+                flash_attn_varlen_func(
+                    q=query[:num_actual_tokens],
+                    k=key_cache,
+                    v=value_cache,
+                    out=output[:num_actual_tokens],
+                    cu_seqlens_q=cu_seqlens_q,
+                    max_seqlen_q=max_seqlen_q,
+                    seqused_k=seqused_k,
+                    max_seqlen_k=max_seqlen_k,
+                    softmax_scale=self.scale,
+                    causal=attn_metadata.causal,
+                    alibi_slopes=self.alibi_slopes,
+                    window_size=self.sliding_window,
+                    block_table=block_table,
+                    softcap=self.logits_soft_cap,
+                    scheduler_metadata=scheduler_metadata,
+                    fa_version=self.vllm_flash_attn_version,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
+                    num_splits=attn_metadata.max_num_splits,
+                    s_aux=self.sinks,
+                )
+            return output
+
+        # Cascade attention (rare case).
+        cascade_attention(
+            output[:num_actual_tokens],
+            query[:num_actual_tokens],
+            key_cache,
+            value_cache,
+            cu_query_lens=attn_metadata.query_start_loc,
+            max_query_len=attn_metadata.max_query_len,
+            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
+            prefix_kv_lens=attn_metadata.prefix_kv_lens,
+            suffix_kv_lens=attn_metadata.suffix_kv_lens,
+            max_kv_len=attn_metadata.max_seq_len,
+            softmax_scale=self.scale,
+            alibi_slopes=self.alibi_slopes,
+            sliding_window=self.sliding_window,
+            logits_soft_cap=self.logits_soft_cap,
+            block_table=attn_metadata.block_table,
+            common_prefix_len=attn_metadata.common_prefix_len,
+            max_num_splits=attn_metadata.max_num_splits,
+            fa_version=self.vllm_flash_attn_version,
+            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
+            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
+            q_descale=layer._q_scale,
+            k_descale=layer._k_scale,
+            v_descale=layer._v_scale,
+            s_aux=self.sinks,
+        )
+        return output