vllm/vllm/v1/attention/backends/linear_attn.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec


class LinearAttentionBackend(AttentionBackend):
    """Attention backend for linear-attention (recurrent-state) layers."""

    @staticmethod
    def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
        return LinearAttentionMetadataBuilder


@dataclass
class LinearAttentionMetadata:
    """Per-batch metadata for linear attention; decodes are ordered before prefills."""

    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int

    query_start_loc: torch.Tensor  # shape: [batch + 1,]
    seq_lens: torch.Tensor  # shape: [batch,]
    state_indices_tensor: torch.Tensor  # shape: [batch,]
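
# Illustrative layout (the values are assumptions for exposition, not taken
# from the source): for a batch of two single-token decodes followed by one
# three-token prefill,
#   query_start_loc      == [0, 1, 2, 5]
#   seq_lens             == total (cached + new) tokens per request
#   state_indices_tensor == one recurrent-state slot id per request
# split_decodes_and_prefills() relies on decodes being ordered first.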


class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMetadata]):
    # Requests with at most this many new tokens are treated as decodes and
    # reordered ahead of prefills in the batch.
    reorder_batch_threshold: int = 1
    # CUDA graphs are captured only for batches made up entirely of
    # single-token decodes.
    _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        # Linear attention keeps a fixed-size recurrent state, so the cache
        # must be described by a MambaSpec rather than a paged KV-cache spec.
        assert isinstance(kv_cache_spec, MambaSpec)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> LinearAttentionMetadata:
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        # Each request owns a single recurrent state, so only the first
        # block-table column is needed as a state index.
        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
            )
        )

        attn_metadata = LinearAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            state_indices_tensor=state_indices_tensor,
        )
        return attn_metadata
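

if __name__ == "__main__":
    # Minimal runnable sketch, assuming illustrative tensor values (none of
    # the values below come from vLLM): demonstrates how state_indices_tensor
    # is derived from the block table for a batch of two single-token decodes
    # followed by one three-token prefill.
    query_start_loc = torch.tensor([0, 1, 2, 5], dtype=torch.int32)
    seq_lens = torch.tensor([7, 12, 3], dtype=torch.int32)
    block_table_tensor = torch.tensor([[4, 0], [9, 0], [2, 0]], dtype=torch.int32)
    # Only the first column is used: one recurrent-state slot per request.
    state_indices_tensor = block_table_tensor[:, 0]
    assert state_indices_tensor.tolist() == [4, 9, 2]
    # query_start_loc has batch + 1 entries; per-request new-token counts are
    # the successive differences: [1, 1, 3] here.
    assert torch.diff(query_start_loc).tolist() == [1, 1, 3]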