# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec


class LinearAttentionBackend(AttentionBackend):
    """Attention backend whose metadata is built by
    ``LinearAttentionMetadataBuilder``."""

    @staticmethod
    def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
        """Return the metadata builder class paired with this backend."""
        return LinearAttentionMetadataBuilder


@dataclass
class LinearAttentionMetadata:
    """Per-batch metadata produced by ``LinearAttentionMetadataBuilder.build``.

    The four counters are the values reported by
    ``split_decodes_and_prefills``; the tensors are taken from the
    ``CommonAttentionMetadata`` handed to the builder.
    """

    # Request / token counts for the prefill part of the batch.
    num_prefills: int
    num_prefill_tokens: int
    # Request / token counts for the decode part of the batch.
    num_decodes: int
    num_decode_tokens: int
    # Forwarded unchanged from the common attention metadata.
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor

    # First column of the block table.
    state_indices_tensor: torch.Tensor  # shape: [batch,]


class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMetadata]):
    """Builds ``LinearAttentionMetadata`` from the common attention metadata."""

    # Passed as ``decode_threshold`` when splitting decodes from prefills.
    reorder_batch_threshold: int = 1

    _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Initialise the base builder and check the KV-cache spec type.

        All arguments are forwarded unchanged to the parent constructor;
        afterwards ``kv_cache_spec`` is asserted to be a ``MambaSpec``.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        assert isinstance(kv_cache_spec, MambaSpec)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> LinearAttentionMetadata:
        """Assemble the linear-attention metadata for one batch.

        Args:
            common_prefix_len: Not read by this implementation.
            common_attn_metadata: Source of ``query_start_loc``, ``seq_lens``
                and ``block_table_tensor``; also the input to
                ``split_decodes_and_prefills``.
            fast_build: Not read by this implementation.

        Returns:
            A ``LinearAttentionMetadata`` carrying the decode/prefill counts
            together with the forwarded tensors.
        """
        common = common_attn_metadata
        start_locs = common.query_start_loc
        lengths = common.seq_lens
        # Column 0 of the block table serves as the per-request state index.
        state_indices = common.block_table_tensor[:, 0]

        # The helper reports the decode/prefill split in this fixed order.
        n_decodes, n_prefills, n_decode_tokens, n_prefill_tokens = (
            split_decodes_and_prefills(
                common, decode_threshold=self.reorder_batch_threshold
            )
        )

        return LinearAttentionMetadata(
            num_prefills=n_prefills,
            num_prefill_tokens=n_prefill_tokens,
            num_decodes=n_decodes,
            num_decode_tokens=n_decode_tokens,
            query_start_loc=start_locs,
            seq_lens=lengths,
            state_indices_tensor=state_indices,
        )