[V0 deprecation] Remove no longer used get_metadata_cls (#28370)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 03fa4d3fb3
commit e8697faf03
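
For context, the pattern being removed is sketched below. This is a minimal reconstruction from the hunks in this diff (only the two deleted methods of the abstract backend, not the full class), with the usual abc imports assumed: every backend advertised its metadata class via get_metadata_cls(), and make_metadata() simply instantiated it.

from abc import ABC, abstractmethod


class AttentionBackend(ABC):
    # Sketch of the two methods this commit deletes from the abstract backend.
    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        # Pure pass-through: resolve the backend's metadata class, then instantiate it.
        return cls.get_metadata_cls()(*args, **kwargs)

Since the commit states get_metadata_cls() is no longer used, every per-backend override of it is dead code; the hunks below delete those overrides while leaving get_builder_cls() untouched.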
@@ -4,24 +4,21 @@
 import itertools
 import random
-import unittest
 from collections.abc import Sequence
 from numbers import Number
 from typing import Any, NamedTuple
+from unittest.mock import patch

 import pytest
 import torch
 from torch._prims_common import TensorLikeType

 from tests.kernels.quant_utils import native_w8a8_block_matmul
-from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
-from vllm.attention.backends.registry import _Backend
+from vllm.attention import AttentionType
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
 from vllm.utils import (
     STR_BACKEND_ENV_VAR,
-    STR_FLASH_ATTN_VAL,
-    STR_XFORMERS_ATTN_VAL,
 )
 from vllm.utils.torch_utils import make_tensor_with_pad

@@ -512,50 +509,6 @@ def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs:
     )


-def make_backend(backend_name: str) -> AttentionBackend:
-    """
-    Construct the backend instance determined by the backend_name string
-    argument.
-
-    Note: at time of writing the Attention wrapper automatically selects
-    its own backend for Attention.forward(); so the backend instance which
-    you generate with this function is not meant to be used for *running*
-    inference, but rather for generating compatible metadata structures
-    using backend.make_metadata()
-
-
-    Returns:
-
-    * Backend instance
-    """
-    if backend_name == STR_XFORMERS_ATTN_VAL:
-        from vllm.v1.attention.backends.xformers import XFormersAttentionBackend
-
-        return XFormersAttentionBackend()
-    if backend_name == STR_FLASH_ATTN_VAL:
-        from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
-
-        return FlashAttentionBackend()
-    if backend_name == "TRITON_ATTN":
-        from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
-
-        return TritonAttentionBackend()
-    if backend_name == "FLEX_ATTENTION":
-        from vllm.v1.attention.backends.flex_attention import FlexAttentionBackend
-
-        return FlexAttentionBackend()
-    if backend_name == "TORCH_SDPA":
-        from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend
-
-        return TorchSDPABackend()
-    if backend_name == "FLASHINFER":
-        from vllm.v1.attention.backends.flashinfer import FlashInferBackend
-
-        return FlashInferBackend()
-
-    raise AssertionError(f"Unrecognized backend_name {backend_name} for unit test")
-
-
 def make_alibi_bias(
     alibi_slopes: torch.Tensor,
     num_kv_heads: int,

@@ -877,197 +830,6 @@ def make_block_tables_slot_mapping(
     return (block_tables_tensor, slot_mapping_list, max_block_idx)


-def make_test_metadata(
-    attn_backend: _Backend,
-    is_prompt: bool,
-    seq_lens: list[int] | None,
-    decoder_test_params: PhaseTestParameters | None,
-    device: torch.device | str,
-    encoder_test_params: PhaseTestParameters | None = None,
-    cross_test_params: PhaseTestParameters | None = None,
-) -> AttentionMetadata:
-    """
-    Construct fake attention metadata for a given test phase
-    (prefill-phase or decode-phase).
-
-    encoder_test_params and cross_test_params arguments allow encoder
-    attention and enc/dec cross-attention (respectively) to use distinct
-    metadata values from decoder self-attention (decoder_test_params.)
-
-    if encoder_test_params and cross_test_params are None, the attention
-    metadata will support decoder-only scenario.
-
-    Assumptions:
-
-    * No chunked prefill -> a batch is 100% prefill or 100% decode, never both
-
-    Arguments:
-
-    * attn_backend_name: Backend for sourcing attention kernels
-    * is_prompt: prefill if True, o/w decode
-    * seq_lens: list of token counts for each sequence
-    * decoder_test_params: decoder self-attention test params;
-                           this function requires
-                           kv_mmap (memory mapping) field
-    * device: CPU or CUDA device
-    * encoder_test_params: encoder attention test params;
-                           this function requires encoder query
-                           sequence lengths field. If None,
-                           encoder query sequence lengths are
-                           treated as None
-    * cross_test_params: enc/dec cross-attention test params;
-                         this function requires kv_mmap field.
-                         If None, KV cache memory map data
-                         structures are treated as None
-
-    Return:
-
-    * AttentionMetadata structure
-    """
-
-    # Decoder self-attention memory mapping
-    # decoder_test_params is None signals encoder-only
-    # scenario, so kv_mmap is None
-    kv_mmap = None if decoder_test_params is None else decoder_test_params.kv_mmap
-
-    # This function constructs metadata assuming no chunked prefill,
-    # i.e. 100% prefill tokens or 100% decode tokens
-    #
-    # - If is_prompt, num_prefills_or_decodes is the number of prefills
-    #   and num_prefill_or_decode_tokens is the number of prefill tokens
-    # - If not is_prompt, num_prefills_or_decodes is the number of decodes
-    #   and num_prefill_or_decode_tokens is the number of decode tokens
-    #
-    # seq_lens is None signals encoder-only
-    # scenario, in which case num_prefills_or_decodes and
-    # num_prefill_or_decode_tokens are unused
-    num_prefills_or_decodes = None if seq_lens is None else len(seq_lens)
-
-    num_prefill_or_decode_tokens = (
-        None if seq_lens is None else (sum(seq_lens) if is_prompt else len(seq_lens))
-    )
-
-    # Seems for non-prefix-caching scenarios context_lens
-    # is never needed
-    context_lens = None
-
-    if encoder_test_params is None:
-        encoder_seq_lens = None
-        num_encoder_tokens = None
-    else:
-        # Encoder/decoder or encoder-only models only:
-        # * Extract encoder input sequence lengths
-        assert encoder_test_params.packed_qkvo.packed_qkv is not None
-        encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
-        num_encoder_tokens = (
-            None if encoder_seq_lens is None else (sum(encoder_seq_lens))
-        )
-
-    # For encoder/decoder or encoder-only models only, extract *cross-attention*
-    # slot_mapping and block table (kv_mmap)
-    cross_kv_mmap = None if cross_test_params is None else cross_test_params.kv_mmap
-
-    attn_backend_obj = make_backend(attn_backend.name)
-
-    if is_prompt:
-        # Prefill-phase scenario
-
-        num_prefills = num_prefills_or_decodes
-        num_prefill_tokens = num_prefill_or_decode_tokens
-        num_decode_tokens = 0
-
-        (
-            seq_lens_tensor,
-            context_lens_tensor,
-            _,
-            _,
-            seq_start_loc,
-            encoder_seq_lens_tensor,
-            encoder_seq_start_loc,
-            max_encoder_seq_len,
-        ) = _make_metadata_tensors(
-            seq_lens, context_lens, encoder_seq_lens, device=device
-        )
-        return attn_backend_obj.make_metadata(
-            num_prefills=num_prefills,
-            slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
-            enable_kv_scales_calculation=True,
-            num_prefill_tokens=num_prefill_tokens,
-            num_decode_tokens=num_decode_tokens,
-            seq_lens=seq_lens,
-            seq_lens_tensor=seq_lens_tensor,
-            seq_start_loc=seq_start_loc,
-            max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
-            max_decode_seq_len=0,
-            context_lens_tensor=context_lens_tensor,
-            block_tables=(None if kv_mmap is None else kv_mmap.block_tables),
-            use_cuda_graph=False,
-            num_encoder_tokens=num_encoder_tokens,
-            encoder_seq_lens=encoder_seq_lens,
-            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
-            encoder_seq_start_loc=encoder_seq_start_loc,
-            max_encoder_seq_len=max_encoder_seq_len,
-            cross_slot_mapping=(
-                None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping
-            ),
-            cross_block_tables=(
-                None if cross_kv_mmap is None else cross_kv_mmap.block_tables
-            ),
-        )
-
-    else:  # not is_prompt
-        # Decode-phase scenario
-
-        assert kv_mmap is not None
-        assert num_prefill_or_decode_tokens is not None
-        assert seq_lens is not None
-
-        num_prefills = 0
-        num_prefill_tokens = 0
-        num_decode_tokens = num_prefill_or_decode_tokens
-
-        (
-            seq_lens_tensor,
-            context_lens_tensor,
-            _,
-            _,
-            seq_start_loc,
-            encoder_seq_lens_tensor,
-            encoder_seq_start_loc,
-            max_encoder_seq_len,
-        ) = _make_metadata_tensors(
-            seq_lens, context_lens, encoder_seq_lens, device=device
-        )
-
-        return attn_backend_obj.make_metadata(
-            num_prefills=num_prefills,
-            slot_mapping=kv_mmap.slot_mapping,
-            enable_kv_scales_calculation=True,
-            num_prefill_tokens=num_prefill_tokens,
-            num_decode_tokens=num_decode_tokens,
-            seq_lens=seq_lens,
-            seq_lens_tensor=seq_lens_tensor,
-            seq_start_loc=seq_start_loc,
-            max_prefill_seq_len=0,
-            max_decode_seq_len=max(seq_lens),
-            max_decode_query_len=1,
-            context_lens_tensor=context_lens_tensor,
-            block_tables=kv_mmap.block_tables,
-            use_cuda_graph=False,
-            num_encoder_tokens=num_encoder_tokens,
-            encoder_seq_lens=encoder_seq_lens,
-            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
-            encoder_seq_start_loc=encoder_seq_start_loc,
-            max_encoder_seq_len=max_encoder_seq_len,
-            cross_slot_mapping=(
-                None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping
-            ),
-            cross_block_tables=(
-                None if cross_kv_mmap is None else cross_kv_mmap.block_tables
-            ),
-        )
-
-
 def assert_actual_matches_ideal(
     test_params: PhaseTestParameters, output_under_test: torch.Tensor, backend: str
 ) -> None:

@@ -1308,7 +1070,7 @@ def opcheck(
     raise_exception: bool = True,
     cond: bool = True,
 ) -> dict[str, str]:
-    with unittest.mock.patch("torch.allclose", new=fp8_allclose):
+    with patch("torch.allclose", new=fp8_allclose):
        return (
            torch.library.opcheck(
                op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
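
Aside from deleting the now-unused make_backend and make_test_metadata helpers, the remaining change in these test-utility hunks is cosmetic: opcheck() uses the patch name imported at module level instead of spelling out unittest.mock.patch, which is why the bare import unittest at the top of the file can go. The two spellings are equivalent, as in this small stand-alone sketch (fake_allclose is a stand-in for the file's fp8_allclose helper, not the real implementation):

import unittest.mock
from unittest.mock import patch

import torch


def fake_allclose(a: torch.Tensor, b: torch.Tensor, **kwargs) -> bool:
    # Stand-in comparison; uses isclose so the patched torch.allclose is not re-entered.
    return bool(torch.isclose(a.float(), b.float(), **kwargs).all())


x = torch.ones(2)

with unittest.mock.patch("torch.allclose", new=fake_allclose):  # old spelling
    assert torch.allclose(x, x)

with patch("torch.allclose", new=fake_allclose):  # new spelling
    assert torch.allclose(x, x)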
@@ -51,19 +51,10 @@ class AttentionBackend(ABC):
     def get_impl_cls() -> type["AttentionImpl"]:
         raise NotImplementedError

-    @staticmethod
-    @abstractmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        raise NotImplementedError
-
     @classmethod
     def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
         return cls.get_impl_cls().get_supported_kernel_block_size()

-    @classmethod
-    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
-        return cls.get_metadata_cls()(*args, **kwargs)
-
     @staticmethod
     @abstractmethod
     def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
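
After this hunk, the abstract AttentionBackend no longer exposes a metadata class at all; a concrete backend only points at its impl and metadata-builder classes. A hypothetical minimal subclass following the pattern visible in the backend hunks below might look like the sketch here; MyAttentionImpl, MyAttentionMetadataBuilder and the "MY_ATTN" name are placeholders rather than vLLM classes, and the import assumes a vLLM checkout is on the path.

from vllm.attention.backends.abstract import AttentionBackend


class MyAttentionImpl:  # placeholder for a real AttentionImpl subclass
    ...


class MyAttentionMetadataBuilder:  # placeholder for a real AttentionMetadataBuilder
    ...


class MyAttentionBackend(AttentionBackend):
    # No get_metadata_cls() / make_metadata() any more: metadata objects are
    # produced by the builder class that get_builder_cls() returns.
    @staticmethod
    def get_name() -> str:
        return "MY_ATTN"

    @staticmethod
    def get_impl_cls() -> type[MyAttentionImpl]:
        return MyAttentionImpl

    @staticmethod
    def get_builder_cls() -> type[MyAttentionMetadataBuilder]:
        return MyAttentionMetadataBuilder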
@@ -66,10 +66,6 @@ class TorchSDPABackend(AttentionBackend):
     def get_impl_cls() -> type["TorchSDPABackendImpl"]:
         return TorchSDPABackendImpl

-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return TorchSDPAMetadata
-
     @staticmethod
     def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
         return TorchSDPAMetadataBuilderV1

@@ -11,7 +11,6 @@ from vllm import envs
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     MultipleOf,
     is_quantized_kv_cache,

@@ -90,10 +89,6 @@ class FlashAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["FlashAttentionImpl"]:
         return FlashAttentionImpl

-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return FlashAttentionMetadata
-
     @staticmethod
     def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
         return FlashAttentionMetadataBuilder

@@ -195,10 +195,6 @@ class FlashInferBackend(AttentionBackend):
     def get_impl_cls() -> type["FlashInferImpl"]:
         return FlashInferImpl

-    @staticmethod
-    def get_metadata_cls() -> type["FlashInferMetadata"]:
-        return FlashInferMetadata
-
     @staticmethod
     def get_builder_cls() -> type["FlashInferMetadataBuilder"]:
         return FlashInferMetadataBuilder

@@ -20,7 +20,6 @@ from torch.nn.attention.flex_attention import (
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     is_quantized_kv_cache,
 )

@@ -89,10 +88,6 @@ class FlexAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["FlexAttentionImpl"]:
         return FlexAttentionImpl

-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return FlexAttentionMetadata
-
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

@@ -201,7 +201,6 @@ from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionLayer,
-    AttentionMetadata,
     MLAAttentionImpl,
 )
 from vllm.attention.backends.utils import get_mla_dims

@@ -307,10 +306,6 @@ class MLACommonBackend(AttentionBackend):
     def get_name() -> str:
         return "TRITON_MLA"

-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return MLACommonMetadata
-
     @staticmethod
     def get_builder_cls() -> type["MLACommonMetadataBuilder"]:
         return MLACommonMetadataBuilder

@@ -41,10 +41,6 @@ class FlashAttnMLABackend(MLACommonBackend):
     def get_name() -> str:
         return "FLASH_ATTN_MLA"

-    @staticmethod
-    def get_metadata_cls() -> type["FlashAttnMLAMetadata"]:
-        return FlashAttnMLAMetadata
-
     @staticmethod
     def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
         return FlashAttnMLAMetadataBuilder

@@ -40,10 +40,6 @@ class FlashMLABackend(MLACommonBackend):
     def get_name() -> str:
         return "FLASHMLA"

-    @staticmethod
-    def get_metadata_cls() -> type["FlashMLAMetadata"]:
-        return FlashMLAMetadata
-
     @staticmethod
     def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
         return FlashMLAMetadataBuilder

@@ -10,7 +10,6 @@ from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionLayer,
-    AttentionMetadata,
 )
 from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.flashmla import (

@@ -57,10 +56,6 @@ class FlashMLASparseBackend(AttentionBackend):
     def get_name() -> str:
         return "FLASHMLA_SPARSE"

-    @staticmethod
-    def get_metadata_cls() -> type[AttentionMetadata]:
-        return FlashMLASparseMetadata
-
     @staticmethod
     def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
         return FlashMLASparseMetadataBuilder

@@ -7,7 +7,6 @@ import torch

 from vllm.attention.backends.abstract import (
     AttentionBackend,
-    AttentionMetadata,
     MultipleOf,
 )
 from vllm.config import VllmConfig

@@ -24,10 +23,6 @@ logger = init_logger(__name__)


 class DeepseekV32IndexerBackend(AttentionBackend):
-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return DeepseekV32IndexerMetadata
-
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 128]

@@ -35,10 +35,6 @@ class AiterMLABackend(MLACommonBackend):
     def get_impl_cls() -> type["AiterMLAImpl"]:
         return AiterMLAImpl

-    @staticmethod
-    def get_metadata_cls() -> type["AiterMLAMetadata"]:
-        return AiterMLAMetadata
-
     @staticmethod
     def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
         return AiterMLAMetadataBuilder

@@ -108,10 +108,6 @@ class PallasAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
         return PallasAttentionBackendImpl

-    @staticmethod
-    def get_metadata_cls() -> type["PallasMetadata"]:
-        return PallasMetadata
-
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

@@ -9,7 +9,6 @@ import torch
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     MultipleOf,
 )

@@ -479,10 +478,6 @@ class AiterFlashAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["AiterFlashAttentionImpl"]:
         return AiterFlashAttentionImpl

-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return AiterFlashAttentionMetadata
-
     @staticmethod
     def get_builder_cls() -> type["AiterFlashAttentionMetadataBuilder"]:
         return AiterFlashAttentionMetadataBuilder

@@ -5,7 +5,7 @@
 import torch

 from vllm import _custom_ops as ops
-from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
+from vllm.attention.backends.abstract import AttentionType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,

@@ -15,7 +15,6 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.rocm_attn import (
     RocmAttentionBackend,
     RocmAttentionImpl,
-    RocmAttentionMetadata,
     RocmAttentionMetadataBuilder,
 )

@@ -33,10 +32,6 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
     def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]:
         return RocmAiterUnifiedAttentionImpl

-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return RocmAttentionMetadata
-
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

@@ -10,7 +10,6 @@ import torch
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
 )
 from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode

@@ -182,10 +181,6 @@ class RocmAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["RocmAttentionImpl"]:
         return RocmAttentionImpl

-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return RocmAttentionMetadata
-
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

@@ -12,7 +12,6 @@ from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     MultipleOf,
 )

@@ -64,10 +63,6 @@ class TreeAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["TreeAttentionImpl"]:
         return TreeAttentionImpl

-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return TreeAttentionMetadata
-
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

@@ -10,7 +10,6 @@ import torch
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     MultipleOf,
 )

@@ -176,10 +175,6 @@ class TritonAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["TritonAttentionImpl"]:
         return TritonAttentionImpl

-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return TritonAttentionMetadata
-
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

@@ -10,7 +10,6 @@ import torch
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
-    AttentionMetadata,
     AttentionType,
     MultipleOf,
 )

@@ -105,10 +104,6 @@ class XFormersAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["XFormersAttentionImpl"]:
         return XFormersAttentionImpl

-    @staticmethod
-    def get_metadata_cls() -> type["AttentionMetadata"]:
-        return XFormersAttentionMetadata
-
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
@@ -20,7 +20,11 @@ from tqdm import tqdm

 import vllm.envs as envs
 from vllm.attention import Attention, AttentionType
-from vllm.attention.backends.abstract import AttentionBackend, MultipleOf
+from vllm.attention.backends.abstract import (
+    AttentionBackend,
+    AttentionMetadata,
+    MultipleOf,
+)
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.cuda_graph import CUDAGraphWrapper
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled

@@ -82,7 +86,6 @@ from vllm.utils.torch_utils import (
     kv_cache_dtype_str_to_dtype,
     supports_dynamo,
 )
-from vllm.v1.attention.backends.flash_attn import AttentionMetadata
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,