[V0 deprecation] Remove no longer used get_metadata_cls (#28370)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Author: Lucas Wilkinson, 2025-11-10 01:32:09 -05:00, committed by GitHub
parent 03fa4d3fb3
commit e8697faf03
20 changed files with 9 additions and 332 deletions

View File

@@ -4,24 +4,21 @@
 import itertools
 import random
-import unittest
 from collections.abc import Sequence
 from numbers import Number
 from typing import Any, NamedTuple
+from unittest.mock import patch
 import pytest
 import torch
 from torch._prims_common import TensorLikeType
 from tests.kernels.quant_utils import native_w8a8_block_matmul
-from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
-from vllm.attention.backends.registry import _Backend
+from vllm.attention import AttentionType
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
 from vllm.utils import (
     STR_BACKEND_ENV_VAR,
-    STR_FLASH_ATTN_VAL,
-    STR_XFORMERS_ATTN_VAL,
 )
 from vllm.utils.torch_utils import make_tensor_with_pad
@@ -512,50 +509,6 @@ def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs:
     )
-def make_backend(backend_name: str) -> AttentionBackend:
-    """
-    Construct the backend instance determined by the backend_name string
-    argument.
-    Note: at time of writing the Attention wrapper automatically selects
-    its own backend for Attention.forward(); so the backend instance which
-    you generate with this function is not meant to be used for *running*
-    inference, but rather for generating compatible metadata structures
-    using backend.make_metadata()
-    Returns:
-    * Backend instance
-    """
-    if backend_name == STR_XFORMERS_ATTN_VAL:
-        from vllm.v1.attention.backends.xformers import XFormersAttentionBackend
-        return XFormersAttentionBackend()
-    if backend_name == STR_FLASH_ATTN_VAL:
-        from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
-        return FlashAttentionBackend()
-    if backend_name == "TRITON_ATTN":
-        from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
-        return TritonAttentionBackend()
-    if backend_name == "FLEX_ATTENTION":
-        from vllm.v1.attention.backends.flex_attention import FlexAttentionBackend
-        return FlexAttentionBackend()
-    if backend_name == "TORCH_SDPA":
-        from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend
-        return TorchSDPABackend()
-    if backend_name == "FLASHINFER":
-        from vllm.v1.attention.backends.flashinfer import FlashInferBackend
-        return FlashInferBackend()
-    raise AssertionError(f"Unrecognized backend_name {backend_name} for unit test")
 def make_alibi_bias(
     alibi_slopes: torch.Tensor,
     num_kv_heads: int,
@@ -877,197 +830,6 @@ def make_block_tables_slot_mapping(
     return (block_tables_tensor, slot_mapping_list, max_block_idx)
-def make_test_metadata(
-    attn_backend: _Backend,
-    is_prompt: bool,
-    seq_lens: list[int] | None,
-    decoder_test_params: PhaseTestParameters | None,
-    device: torch.device | str,
-    encoder_test_params: PhaseTestParameters | None = None,
-    cross_test_params: PhaseTestParameters | None = None,
-) -> AttentionMetadata:
-    """
-    Construct fake attention metadata for a given test phase
-    (prefill-phase or decode-phase).
-    encoder_test_params and cross_test_params arguments allow encoder
-    attention and enc/dec cross-attention (respectively) to use distinct
-    metadata values from decoder self-attention (decoder_test_params.)
-    if encoder_test_params and cross_test_params are None, the attention
-    metadata will support decoder-only scenario.
-    Assumptions:
-    * No chunked prefill -> a batch is 100% prefill or 100% decode, never both
-    Arguments:
-    * attn_backend_name: Backend for sourcing attention kernels
-    * is_prompt: prefill if True, o/w decode
-    * seq_lens: list of token counts for each sequence
-    * decoder_test_params: decoder self-attention test params;
-                           this function requires
-                           kv_mmap (memory mapping) field
-    * device: CPU or CUDA device
-    * encoder_test_params: encoder attention test params;
-                           this function requires encoder query
-                           sequence lengths field. If None,
-                           encoder query sequence lengths are
-                           treated as None
-    * cross_test_params: enc/dec cross-attention test params;
-                         this function requires kv_mmap field.
-                         If None, KV cache memory map data
-                         structures are treated as None
-    Return:
-    * AttentionMetadata structure
-    """
-    # Decoder self-attention memory mapping
-    # decoder_test_params is None signals encoder-only
-    # scenario, so kv_mmap is None
-    kv_mmap = None if decoder_test_params is None else decoder_test_params.kv_mmap
-    # This function constructs metadata assuming no chunked prefill,
-    # i.e. 100% prefill tokens or 100% decode tokens
-    #
-    # - If is_prompt, num_prefills_or_decodes is the number of prefills
-    #   and num_prefill_or_decode_tokens is the number of prefill tokens
-    # - If not is_prompt, num_prefills_or_decodes is the number of decodes
-    #   and num_prefill_or_decode_tokens is the number of decode tokens
-    #
-    # seq_lens is None signals encoder-only
-    # scenario, in which case num_prefills_or_decodes and
-    # num_prefill_or_decode_tokens are unused
-    num_prefills_or_decodes = None if seq_lens is None else len(seq_lens)
-    num_prefill_or_decode_tokens = (
-        None if seq_lens is None else (sum(seq_lens) if is_prompt else len(seq_lens))
-    )
-    # Seems for non-prefix-caching scenarios context_lens
-    # is never needed
-    context_lens = None
-    if encoder_test_params is None:
-        encoder_seq_lens = None
-        num_encoder_tokens = None
-    else:
-        # Encoder/decoder or encoder-only models only:
-        # * Extract encoder input sequence lengths
-        assert encoder_test_params.packed_qkvo.packed_qkv is not None
-        encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
-        num_encoder_tokens = (
-            None if encoder_seq_lens is None else (sum(encoder_seq_lens))
-        )
-    # For encoder/decoder or encoder-only models only, extract *cross-attention*
-    # slot_mapping and block table (kv_mmap)
-    cross_kv_mmap = None if cross_test_params is None else cross_test_params.kv_mmap
-    attn_backend_obj = make_backend(attn_backend.name)
-    if is_prompt:
-        # Prefill-phase scenario
-        num_prefills = num_prefills_or_decodes
-        num_prefill_tokens = num_prefill_or_decode_tokens
-        num_decode_tokens = 0
-        (
-            seq_lens_tensor,
-            context_lens_tensor,
-            _,
-            _,
-            seq_start_loc,
-            encoder_seq_lens_tensor,
-            encoder_seq_start_loc,
-            max_encoder_seq_len,
-        ) = _make_metadata_tensors(
-            seq_lens, context_lens, encoder_seq_lens, device=device
-        )
-        return attn_backend_obj.make_metadata(
-            num_prefills=num_prefills,
-            slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
-            enable_kv_scales_calculation=True,
-            num_prefill_tokens=num_prefill_tokens,
-            num_decode_tokens=num_decode_tokens,
-            seq_lens=seq_lens,
-            seq_lens_tensor=seq_lens_tensor,
-            seq_start_loc=seq_start_loc,
-            max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
-            max_decode_seq_len=0,
-            context_lens_tensor=context_lens_tensor,
-            block_tables=(None if kv_mmap is None else kv_mmap.block_tables),
-            use_cuda_graph=False,
-            num_encoder_tokens=num_encoder_tokens,
-            encoder_seq_lens=encoder_seq_lens,
-            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
-            encoder_seq_start_loc=encoder_seq_start_loc,
-            max_encoder_seq_len=max_encoder_seq_len,
-            cross_slot_mapping=(
-                None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping
-            ),
-            cross_block_tables=(
-                None if cross_kv_mmap is None else cross_kv_mmap.block_tables
-            ),
-        )
-    else:  # not is_prompt
-        # Decode-phase scenario
-        assert kv_mmap is not None
-        assert num_prefill_or_decode_tokens is not None
-        assert seq_lens is not None
-        num_prefills = 0
-        num_prefill_tokens = 0
-        num_decode_tokens = num_prefill_or_decode_tokens
-        (
-            seq_lens_tensor,
-            context_lens_tensor,
-            _,
-            _,
-            seq_start_loc,
-            encoder_seq_lens_tensor,
-            encoder_seq_start_loc,
-            max_encoder_seq_len,
-        ) = _make_metadata_tensors(
-            seq_lens, context_lens, encoder_seq_lens, device=device
-        )
-        return attn_backend_obj.make_metadata(
-            num_prefills=num_prefills,
-            slot_mapping=kv_mmap.slot_mapping,
-            enable_kv_scales_calculation=True,
-            num_prefill_tokens=num_prefill_tokens,
-            num_decode_tokens=num_decode_tokens,
-            seq_lens=seq_lens,
-            seq_lens_tensor=seq_lens_tensor,
-            seq_start_loc=seq_start_loc,
-            max_prefill_seq_len=0,
-            max_decode_seq_len=max(seq_lens),
-            max_decode_query_len=1,
-            context_lens_tensor=context_lens_tensor,
-            block_tables=kv_mmap.block_tables,
-            use_cuda_graph=False,
-            num_encoder_tokens=num_encoder_tokens,
-            encoder_seq_lens=encoder_seq_lens,
-            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
-            encoder_seq_start_loc=encoder_seq_start_loc,
-            max_encoder_seq_len=max_encoder_seq_len,
-            cross_slot_mapping=(
-                None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping
-            ),
-            cross_block_tables=(
-                None if cross_kv_mmap is None else cross_kv_mmap.block_tables
-            ),
-        )
 def assert_actual_matches_ideal(
     test_params: PhaseTestParameters, output_under_test: torch.Tensor, backend: str
 ) -> None:
@@ -1308,7 +1070,7 @@ def opcheck(
     raise_exception: bool = True,
     cond: bool = True,
 ) -> dict[str, str]:
-    with unittest.mock.patch("torch.allclose", new=fp8_allclose):
+    with patch("torch.allclose", new=fp8_allclose):
         return (
             torch.library.opcheck(
                 op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
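The hunk above only swaps the fully qualified `unittest.mock.patch` for the directly imported `patch`; behaviour is unchanged. For readers unfamiliar with the pattern, here is a small self-contained sketch of what patching `torch.allclose` inside a context manager does; `_loose_allclose` is an assumed stand-in for this file's `fp8_allclose` helper, not code from the PR.

```python
import torch
from unittest.mock import patch


def _loose_allclose(a: torch.Tensor, b: torch.Tensor, **kwargs) -> bool:
    # Assumed stand-in: compare with widened tolerances, roughly what an
    # fp8-aware allclose replacement would do.
    return bool(torch.isclose(a, b, rtol=1e-1, atol=1e-2).all())


x = torch.randn(8)
with patch("torch.allclose", new=_loose_allclose):
    # Inside the context, torch.allclose resolves to the lenient comparison.
    assert torch.allclose(x, x + 1e-3)
assert torch.allclose(x, x)  # outside, the original torch.allclose is restored
```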

View File

@@ -51,19 +51,10 @@ class AttentionBackend(ABC):
     def get_impl_cls() -> type["AttentionImpl"]:
         raise NotImplementedError
-    @staticmethod
-    @abstractmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        raise NotImplementedError
     @classmethod
     def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
         return cls.get_impl_cls().get_supported_kernel_block_size()
-    @classmethod
-    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
-        return cls.get_metadata_cls()(*args, **kwargs)
     @staticmethod
     @abstractmethod
     def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
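The removed `make_metadata` was a thin indirection: it simply instantiated whatever class `get_metadata_cls()` returned, and per the commit title it had no remaining users, while the per-backend builders returned by `get_builder_cls()` stay in place. A minimal stand-in sketch of the deleted pattern follows; the `_Demo*` names are illustrative only, not vLLM classes.

```python
from dataclasses import dataclass


@dataclass
class _DemoMetadata:
    # Illustrative stand-in for an attention metadata dataclass.
    num_prefills: int = 0
    num_decode_tokens: int = 0


class _DemoBackend:
    @staticmethod
    def get_metadata_cls() -> type[_DemoMetadata]:
        return _DemoMetadata

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> _DemoMetadata:
        # The one-line indirection this PR deletes from AttentionBackend.
        return cls.get_metadata_cls()(*args, **kwargs)


md = _DemoBackend.make_metadata(num_prefills=2)
assert md.num_prefills == 2
```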

View File

@@ -66,10 +66,6 @@ class TorchSDPABackend(AttentionBackend):
     def get_impl_cls() -> type["TorchSDPABackendImpl"]:
         return TorchSDPABackendImpl
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return TorchSDPAMetadata
     @staticmethod
     def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
         return TorchSDPAMetadataBuilderV1

View File

@@ -11,7 +11,6 @@ from vllm import envs
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     MultipleOf,
     is_quantized_kv_cache,
@@ -90,10 +89,6 @@ class FlashAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["FlashAttentionImpl"]:
         return FlashAttentionImpl
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return FlashAttentionMetadata
     @staticmethod
     def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
         return FlashAttentionMetadataBuilder

View File

@@ -195,10 +195,6 @@ class FlashInferBackend(AttentionBackend):
     def get_impl_cls() -> type["FlashInferImpl"]:
         return FlashInferImpl
-    @staticmethod
-    def get_metadata_cls() -> type["FlashInferMetadata"]:
-        return FlashInferMetadata
     @staticmethod
     def get_builder_cls() -> type["FlashInferMetadataBuilder"]:
         return FlashInferMetadataBuilder

View File

@@ -20,7 +20,6 @@ from torch.nn.attention.flex_attention import (
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     is_quantized_kv_cache,
 )
@@ -89,10 +88,6 @@ class FlexAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["FlexAttentionImpl"]:
         return FlexAttentionImpl
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return FlexAttentionMetadata
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

View File

@@ -201,7 +201,6 @@ from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionLayer,
-    AttentionMetadata,
     MLAAttentionImpl,
 )
 from vllm.attention.backends.utils import get_mla_dims
@@ -307,10 +306,6 @@ class MLACommonBackend(AttentionBackend):
     def get_name() -> str:
         return "TRITON_MLA"
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return MLACommonMetadata
     @staticmethod
     def get_builder_cls() -> type["MLACommonMetadataBuilder"]:
         return MLACommonMetadataBuilder

View File

@@ -41,10 +41,6 @@ class FlashAttnMLABackend(MLACommonBackend):
     def get_name() -> str:
         return "FLASH_ATTN_MLA"
-    @staticmethod
-    def get_metadata_cls() -> type["FlashAttnMLAMetadata"]:
-        return FlashAttnMLAMetadata
     @staticmethod
     def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
         return FlashAttnMLAMetadataBuilder

View File

@@ -40,10 +40,6 @@ class FlashMLABackend(MLACommonBackend):
     def get_name() -> str:
         return "FLASHMLA"
-    @staticmethod
-    def get_metadata_cls() -> type["FlashMLAMetadata"]:
-        return FlashMLAMetadata
     @staticmethod
     def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
         return FlashMLAMetadataBuilder

View File

@@ -10,7 +10,6 @@ from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionLayer,
-    AttentionMetadata,
 )
 from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.flashmla import (
@@ -57,10 +56,6 @@ class FlashMLASparseBackend(AttentionBackend):
     def get_name() -> str:
         return "FLASHMLA_SPARSE"
-    @staticmethod
-    def get_metadata_cls() -> type[AttentionMetadata]:
-        return FlashMLASparseMetadata
     @staticmethod
     def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
         return FlashMLASparseMetadataBuilder

View File

@@ -7,7 +7,6 @@ import torch
 from vllm.attention.backends.abstract import (
     AttentionBackend,
-    AttentionMetadata,
     MultipleOf,
 )
 from vllm.config import VllmConfig
@@ -24,10 +23,6 @@ logger = init_logger(__name__)
 class DeepseekV32IndexerBackend(AttentionBackend):
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return DeepseekV32IndexerMetadata
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 128]

View File

@@ -35,10 +35,6 @@ class AiterMLABackend(MLACommonBackend):
     def get_impl_cls() -> type["AiterMLAImpl"]:
         return AiterMLAImpl
-    @staticmethod
-    def get_metadata_cls() -> type["AiterMLAMetadata"]:
-        return AiterMLAMetadata
     @staticmethod
     def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
         return AiterMLAMetadataBuilder

View File

@@ -108,10 +108,6 @@ class PallasAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
         return PallasAttentionBackendImpl
-    @staticmethod
-    def get_metadata_cls() -> type["PallasMetadata"]:
-        return PallasMetadata
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

View File

@@ -9,7 +9,6 @@ import torch
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     MultipleOf,
 )
@@ -479,10 +478,6 @@ class AiterFlashAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["AiterFlashAttentionImpl"]:
         return AiterFlashAttentionImpl
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return AiterFlashAttentionMetadata
     @staticmethod
     def get_builder_cls() -> type["AiterFlashAttentionMetadataBuilder"]:
         return AiterFlashAttentionMetadataBuilder

View File

@@ -5,7 +5,7 @@
 import torch
 from vllm import _custom_ops as ops
-from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
+from vllm.attention.backends.abstract import AttentionType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
@@ -15,7+15,6 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.rocm_attn import (
     RocmAttentionBackend,
     RocmAttentionImpl,
-    RocmAttentionMetadata,
     RocmAttentionMetadataBuilder,
 )
@@ -33,10 +32,6 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
     def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]:
         return RocmAiterUnifiedAttentionImpl
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return RocmAttentionMetadata
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

View File

@@ -10,7 +10,6 @@ import torch
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
 )
 from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
@@ -182,10 +181,6 @@ class RocmAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["RocmAttentionImpl"]:
         return RocmAttentionImpl
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return RocmAttentionMetadata
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

View File

@@ -12,7 +12,6 @@ from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     MultipleOf,
 )
@@ -64,10 +63,6 @@ class TreeAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["TreeAttentionImpl"]:
         return TreeAttentionImpl
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return TreeAttentionMetadata
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

View File

@@ -10,7 +10,6 @@ import torch
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     MultipleOf,
 )
@@ -176,10 +175,6 @@ class TritonAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["TritonAttentionImpl"]:
         return TritonAttentionImpl
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return TritonAttentionMetadata
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

View File

@@ -10,7 +10,6 @@ import torch
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     MultipleOf,
 )
@@ -105,10 +104,6 @@ class XFormersAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["XFormersAttentionImpl"]:
         return XFormersAttentionImpl
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return XFormersAttentionMetadata
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

View File

@@ -20,7 +20,11 @@ from tqdm import tqdm
 import vllm.envs as envs
 from vllm.attention import Attention, AttentionType
-from vllm.attention.backends.abstract import AttentionBackend, MultipleOf
+from vllm.attention.backends.abstract import (
+    AttentionBackend,
+    AttentionMetadata,
+    MultipleOf,
+)
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.cuda_graph import CUDAGraphWrapper
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
@@ -82,7 +86,6 @@ from vllm.utils.torch_utils import (
     kv_cache_dtype_str_to_dtype,
     supports_dynamo,
 )
-from vllm.v1.attention.backends.flash_attn import AttentionMetadata
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,