diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py
index eb00bc72b4b0d..5d5a26fbfc2cd 100644
--- a/tests/kernels/utils.py
+++ b/tests/kernels/utils.py
@@ -4,24 +4,21 @@
 import itertools
 import random
-import unittest
 from collections.abc import Sequence
 from numbers import Number
 from typing import Any, NamedTuple
+from unittest.mock import patch
 
 import pytest
 import torch
 from torch._prims_common import TensorLikeType
 
 from tests.kernels.quant_utils import native_w8a8_block_matmul
-from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
-from vllm.attention.backends.registry import _Backend
+from vllm.attention import AttentionType
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
 from vllm.utils import (
     STR_BACKEND_ENV_VAR,
-    STR_FLASH_ATTN_VAL,
-    STR_XFORMERS_ATTN_VAL,
 )
 from vllm.utils.torch_utils import make_tensor_with_pad
@@ -512,50 +509,6 @@ def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs:
     )
 
 
-def make_backend(backend_name: str) -> AttentionBackend:
-    """
-    Construct the backend instance determined by the backend_name string
-    argument.
-
-    Note: at time of writing the Attention wrapper automatically selects
-    its own backend for Attention.forward(); so the backend instance which
-    you generate with this function is not meant to be used for *running*
-    inference, but rather for generating compatible metadata structures
-    using backend.make_metadata()
-
-
-    Returns:
-
-    * Backend instance
-    """
-    if backend_name == STR_XFORMERS_ATTN_VAL:
-        from vllm.v1.attention.backends.xformers import XFormersAttentionBackend
-
-        return XFormersAttentionBackend()
-    if backend_name == STR_FLASH_ATTN_VAL:
-        from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
-
-        return FlashAttentionBackend()
-    if backend_name == "TRITON_ATTN":
-        from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
-
-        return TritonAttentionBackend()
-    if backend_name == "FLEX_ATTENTION":
-        from vllm.v1.attention.backends.flex_attention import FlexAttentionBackend
-
-        return FlexAttentionBackend()
-    if backend_name == "TORCH_SDPA":
-        from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend
-
-        return TorchSDPABackend()
-    if backend_name == "FLASHINFER":
-        from vllm.v1.attention.backends.flashinfer import FlashInferBackend
-
-        return FlashInferBackend()
-
-    raise AssertionError(f"Unrecognized backend_name {backend_name} for unit test")
-
-
 def make_alibi_bias(
     alibi_slopes: torch.Tensor,
     num_kv_heads: int,
@@ -877,197 +830,6 @@ def make_block_tables_slot_mapping(
     return (block_tables_tensor, slot_mapping_list, max_block_idx)
 
 
-def make_test_metadata(
-    attn_backend: _Backend,
-    is_prompt: bool,
-    seq_lens: list[int] | None,
-    decoder_test_params: PhaseTestParameters | None,
-    device: torch.device | str,
-    encoder_test_params: PhaseTestParameters | None = None,
-    cross_test_params: PhaseTestParameters | None = None,
-) -> AttentionMetadata:
-    """
-    Construct fake attention metadata for a given test phase
-    (prefill-phase or decode-phase).
-
-    encoder_test_params and cross_test_params arguments allow encoder
-    attention and enc/dec cross-attention (respectively) to use distinct
-    metadata values from decoder self-attention (decoder_test_params.)
-
-    if encoder_test_params and cross_test_params are None, the attention
-    metadata will support decoder-only scenario.
-
-    Assumptions:
-
-    * No chunked prefill -> a batch is 100% prefill or 100% decode, never both
-
-    Arguments:
-
-    * attn_backend_name: Backend for sourcing attention kernels
-    * is_prompt: prefill if True, o/w decode
-    * seq_lens: list of token counts for each sequence
-    * decoder_test_params: decoder self-attention test params;
-                           this function requires
-                           kv_mmap (memory mapping) field
-    * device: CPU or CUDA device
-    * encoder_test_params: encoder attention test params;
-                           this function requires encoder query
-                           sequence lengths field. If None,
-                           encoder query sequence lengths are
-                           treated as None
-    * cross_test_params: enc/dec cross-attention test params;
-                         this function requires kv_mmap field.
-                         If None, KV cache memory map data
-                         structures are treated as None
-
-    Return:
-
-    * AttentionMetadata structure
-    """
-
-    # Decoder self-attention memory mapping
-    # decoder_test_params is None signals encoder-only
-    # scenario, so kv_mmap is None
-    kv_mmap = None if decoder_test_params is None else decoder_test_params.kv_mmap
-
-    # This function constructs metadata assuming no chunked prefill,
-    # i.e. 100% prefill tokens or 100% decode tokens
-    #
-    # - If is_prompt, num_prefills_or_decodes is the number of prefills
-    #   and num_prefill_or_decode_tokens is the number of prefill tokens
-    # - If not is_prompt, num_prefills_or_decodes is the number of decodes
-    #   and num_prefill_or_decode_tokens is the number of decode tokens
-    #
-    # seq_lens is None signals encoder-only
-    # scenario, in which case num_prefills_or_decodes and
-    # num_prefill_or_decode_tokens are unused
-    num_prefills_or_decodes = None if seq_lens is None else len(seq_lens)
-
-    num_prefill_or_decode_tokens = (
-        None if seq_lens is None else (sum(seq_lens) if is_prompt else len(seq_lens))
-    )
-
-    # Seems for non-prefix-caching scenarios context_lens
-    # is never needed
-    context_lens = None
-
-    if encoder_test_params is None:
-        encoder_seq_lens = None
-        num_encoder_tokens = None
-    else:
-        # Encoder/decoder or encoder-only models only:
-        # * Extract encoder input sequence lengths
-        assert encoder_test_params.packed_qkvo.packed_qkv is not None
-        encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
-        num_encoder_tokens = (
-            None if encoder_seq_lens is None else (sum(encoder_seq_lens))
-        )
-
-    # For encoder/decoder or encoder-only models only, extract *cross-attention*
-    # slot_mapping and block table (kv_mmap)
-    cross_kv_mmap = None if cross_test_params is None else cross_test_params.kv_mmap
-
-    attn_backend_obj = make_backend(attn_backend.name)
-
-    if is_prompt:
-        # Prefill-phase scenario
-
-        num_prefills = num_prefills_or_decodes
-        num_prefill_tokens = num_prefill_or_decode_tokens
-        num_decode_tokens = 0
-
-        (
-            seq_lens_tensor,
-            context_lens_tensor,
-            _,
-            _,
-            seq_start_loc,
-            encoder_seq_lens_tensor,
-            encoder_seq_start_loc,
-            max_encoder_seq_len,
-        ) = _make_metadata_tensors(
-            seq_lens, context_lens, encoder_seq_lens, device=device
-        )
-        return attn_backend_obj.make_metadata(
-            num_prefills=num_prefills,
-            slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
-            enable_kv_scales_calculation=True,
-            num_prefill_tokens=num_prefill_tokens,
-            num_decode_tokens=num_decode_tokens,
-            seq_lens=seq_lens,
-            seq_lens_tensor=seq_lens_tensor,
-            seq_start_loc=seq_start_loc,
-            max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
-            max_decode_seq_len=0,
-            context_lens_tensor=context_lens_tensor,
-            block_tables=(None if kv_mmap is None else kv_mmap.block_tables),
-            use_cuda_graph=False,
-            num_encoder_tokens=num_encoder_tokens,
-            encoder_seq_lens=encoder_seq_lens,
-            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
-            encoder_seq_start_loc=encoder_seq_start_loc,
-            max_encoder_seq_len=max_encoder_seq_len,
-            cross_slot_mapping=(
-                None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping
-            ),
-            cross_block_tables=(
-                None if cross_kv_mmap is None else cross_kv_mmap.block_tables
-            ),
-        )
-
-    else:  # not is_prompt
-        # Decode-phase scenario
-
-        assert kv_mmap is not None
-        assert num_prefill_or_decode_tokens is not None
-        assert seq_lens is not None
-
-        num_prefills = 0
-        num_prefill_tokens = 0
-        num_decode_tokens = num_prefill_or_decode_tokens
-
-        (
-            seq_lens_tensor,
-            context_lens_tensor,
-            _,
-            _,
-            seq_start_loc,
-            encoder_seq_lens_tensor,
-            encoder_seq_start_loc,
-            max_encoder_seq_len,
-        ) = _make_metadata_tensors(
-            seq_lens, context_lens, encoder_seq_lens, device=device
-        )
-
-        return attn_backend_obj.make_metadata(
-            num_prefills=num_prefills,
-            slot_mapping=kv_mmap.slot_mapping,
-            enable_kv_scales_calculation=True,
-            num_prefill_tokens=num_prefill_tokens,
-            num_decode_tokens=num_decode_tokens,
-            seq_lens=seq_lens,
-            seq_lens_tensor=seq_lens_tensor,
-            seq_start_loc=seq_start_loc,
-            max_prefill_seq_len=0,
-            max_decode_seq_len=max(seq_lens),
-            max_decode_query_len=1,
-            context_lens_tensor=context_lens_tensor,
-            block_tables=kv_mmap.block_tables,
-            use_cuda_graph=False,
-            num_encoder_tokens=num_encoder_tokens,
-            encoder_seq_lens=encoder_seq_lens,
-            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
-            encoder_seq_start_loc=encoder_seq_start_loc,
-            max_encoder_seq_len=max_encoder_seq_len,
-            cross_slot_mapping=(
-                None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping
-            ),
-            cross_block_tables=(
-                None if cross_kv_mmap is None else cross_kv_mmap.block_tables
-            ),
-        )
-
-
 def assert_actual_matches_ideal(
     test_params: PhaseTestParameters, output_under_test: torch.Tensor, backend: str
 ) -> None:
@@ -1308,7 +1070,7 @@ def opcheck(
     raise_exception: bool = True,
     cond: bool = True,
 ) -> dict[str, str]:
-    with unittest.mock.patch("torch.allclose", new=fp8_allclose):
+    with patch("torch.allclose", new=fp8_allclose):
         return (
             torch.library.opcheck(
                 op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index e9c6a278a9411..b54eaf4e2872d 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -51,19 +51,10 @@ class AttentionBackend(ABC):
     def get_impl_cls() -> type["AttentionImpl"]:
         raise NotImplementedError
 
-    @staticmethod
-    @abstractmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        raise NotImplementedError
-
     @classmethod
     def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
         return cls.get_impl_cls().get_supported_kernel_block_size()
 
-    @classmethod
-    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
-        return cls.get_metadata_cls()(*args, **kwargs)
-
     @staticmethod
     @abstractmethod
     def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py
index 0d3e1729ff208..20d987fa2de3b 100644
--- a/vllm/v1/attention/backends/cpu_attn.py
+++ b/vllm/v1/attention/backends/cpu_attn.py
@@ -66,10 +66,6 @@ class TorchSDPABackend(AttentionBackend):
     def get_impl_cls() -> type["TorchSDPABackendImpl"]:
         return TorchSDPABackendImpl
 
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return TorchSDPAMetadata
-
     @staticmethod
     def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
         return TorchSDPAMetadataBuilderV1
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 226f2277ae985..15bb2f4a40acb 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -11,7 +11,6 @@ from vllm import envs
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     MultipleOf,
     is_quantized_kv_cache,
@@ -90,10 +89,6 @@ class FlashAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["FlashAttentionImpl"]:
         return FlashAttentionImpl
 
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return FlashAttentionMetadata
-
     @staticmethod
     def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
         return FlashAttentionMetadataBuilder
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index ddc63b902dffb..683725b95819f 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -195,10 +195,6 @@ class FlashInferBackend(AttentionBackend):
     def get_impl_cls() -> type["FlashInferImpl"]:
         return FlashInferImpl
 
-    @staticmethod
-    def get_metadata_cls() -> type["FlashInferMetadata"]:
-        return FlashInferMetadata
-
     @staticmethod
     def get_builder_cls() -> type["FlashInferMetadataBuilder"]:
         return FlashInferMetadataBuilder
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index 928252636d583..9af63831cecba 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -20,7 +20,6 @@ from torch.nn.attention.flex_attention import (
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     is_quantized_kv_cache,
 )
@@ -89,10 +88,6 @@ class FlexAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["FlexAttentionImpl"]:
         return FlexAttentionImpl
 
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return FlexAttentionMetadata
-
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 6c8145b6847df..40ce12c4bd758 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -201,7 +201,6 @@ from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionLayer,
-    AttentionMetadata,
     MLAAttentionImpl,
 )
 from vllm.attention.backends.utils import get_mla_dims
@@ -307,10 +306,6 @@ class MLACommonBackend(AttentionBackend):
     def get_name() -> str:
         return "TRITON_MLA"
 
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return MLACommonMetadata
-
     @staticmethod
     def get_builder_cls() -> type["MLACommonMetadataBuilder"]:
         return MLACommonMetadataBuilder
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 8a1e79baa87cd..79b89c7890a28 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -41,10 +41,6 @@ class FlashAttnMLABackend(MLACommonBackend):
     def get_name() -> str:
         return "FLASH_ATTN_MLA"
 
-    @staticmethod
-    def get_metadata_cls() -> type["FlashAttnMLAMetadata"]:
-        return FlashAttnMLAMetadata
-
     @staticmethod
     def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
         return FlashAttnMLAMetadataBuilder
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index bc17307532093..708bb9d63839d 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -40,10 +40,6 @@ class FlashMLABackend(MLACommonBackend):
     def get_name() -> str:
         return "FLASHMLA"
 
-    @staticmethod
-    def get_metadata_cls() -> type["FlashMLAMetadata"]:
-        return FlashMLAMetadata
-
     @staticmethod
     def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
         return FlashMLAMetadataBuilder
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index bf8e4d5a62896..bf76549de1ce8 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -10,7 +10,6 @@ from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionLayer,
-    AttentionMetadata,
 )
 from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.flashmla import (
@@ -57,10 +56,6 @@ class FlashMLASparseBackend(AttentionBackend):
     def get_name() -> str:
         return "FLASHMLA_SPARSE"
 
-    @staticmethod
-    def get_metadata_cls() -> type[AttentionMetadata]:
-        return FlashMLASparseMetadata
-
     @staticmethod
     def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
         return FlashMLASparseMetadataBuilder
diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index 49009a939d0b5..f3c5bb7328712 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -7,7 +7,6 @@ import torch
 
 from vllm.attention.backends.abstract import (
     AttentionBackend,
-    AttentionMetadata,
     MultipleOf,
 )
 from vllm.config import VllmConfig
@@ -24,10 +23,6 @@ logger = init_logger(__name__)
 
 
 class DeepseekV32IndexerBackend(AttentionBackend):
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return DeepseekV32IndexerMetadata
-
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 128]
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 71eac84b6f063..4ad7236eb1be3 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -35,10 +35,6 @@ class AiterMLABackend(MLACommonBackend):
     def get_impl_cls() -> type["AiterMLAImpl"]:
         return AiterMLAImpl
 
-    @staticmethod
-    def get_metadata_cls() -> type["AiterMLAMetadata"]:
-        return AiterMLAMetadata
-
     @staticmethod
     def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
         return AiterMLAMetadataBuilder
diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index 40a5517877967..525026bac5a7e 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -108,10 +108,6 @@ class PallasAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
         return PallasAttentionBackendImpl
 
-    @staticmethod
-    def get_metadata_cls() -> type["PallasMetadata"]:
-        return PallasMetadata
-
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index 30e5cafe0c843..e8d3758a6395a 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -9,7 +9,6 @@ import torch
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     MultipleOf,
 )
@@ -479,10 +478,6 @@ class AiterFlashAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["AiterFlashAttentionImpl"]:
         return AiterFlashAttentionImpl
 
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return AiterFlashAttentionMetadata
-
     @staticmethod
     def get_builder_cls() -> type["AiterFlashAttentionMetadataBuilder"]:
         return AiterFlashAttentionMetadataBuilder
diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
index 27b072106268b..b2639c0df0412 100644
--- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
+++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
@@ -5,7 +5,7 @@ import torch
 
 from vllm import _custom_ops as ops
-from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
+from vllm.attention.backends.abstract import AttentionType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
@@ -15,7 +15,6 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.rocm_attn import (
     RocmAttentionBackend,
     RocmAttentionImpl,
-    RocmAttentionMetadata,
     RocmAttentionMetadataBuilder,
 )
@@ -33,10 +32,6 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
     def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]:
         return RocmAiterUnifiedAttentionImpl
 
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return RocmAttentionMetadata
-
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py
index 8b7ce90a3ccae..57ba4dc78d9fd 100644
--- a/vllm/v1/attention/backends/rocm_attn.py
+++ b/vllm/v1/attention/backends/rocm_attn.py
@@ -10,7 +10,6 @@ import torch
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
 )
 from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
@@ -182,10 +181,6 @@ class RocmAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["RocmAttentionImpl"]:
         return RocmAttentionImpl
 
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return RocmAttentionMetadata
-
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py
index ee6ead9ad9b35..0c0222d6152fb 100644
--- a/vllm/v1/attention/backends/tree_attn.py
+++ b/vllm/v1/attention/backends/tree_attn.py
@@ -12,7 +12,6 @@ from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     MultipleOf,
 )
@@ -64,10 +63,6 @@ class TreeAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["TreeAttentionImpl"]:
         return TreeAttentionImpl
 
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return TreeAttentionMetadata
-
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index b1d34dbfd1729..0590a87bf8e5f 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -10,7 +10,6 @@ import torch
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     MultipleOf,
 )
@@ -176,10 +175,6 @@ class TritonAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["TritonAttentionImpl"]:
         return TritonAttentionImpl
 
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return TritonAttentionMetadata
-
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py
index 457b15ebdd82f..81bdbd641429a 100644
--- a/vllm/v1/attention/backends/xformers.py
+++ b/vllm/v1/attention/backends/xformers.py
@@ -10,7 +10,6 @@ import torch
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     MultipleOf,
 )
@@ -105,10 +104,6 @@ class XFormersAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["XFormersAttentionImpl"]:
         return XFormersAttentionImpl
 
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return XFormersAttentionMetadata
-
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index de9f32687635e..26007d29d61b8 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -20,7 +20,11 @@ from tqdm import tqdm
 
 import vllm.envs as envs
 from vllm.attention import Attention, AttentionType
-from vllm.attention.backends.abstract import AttentionBackend, MultipleOf
+from vllm.attention.backends.abstract import (
+    AttentionBackend,
+    AttentionMetadata,
+    MultipleOf,
+)
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.cuda_graph import CUDAGraphWrapper
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
@@ -82,7 +86,6 @@ from vllm.utils.torch_utils import (
     kv_cache_dtype_str_to_dtype,
     supports_dynamo,
 )
-from vllm.v1.attention.backends.flash_attn import AttentionMetadata
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,