[Attention] Move Backend enum into registry (#25893)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni 2025-10-02 23:32:24 -04:00 committed by GitHub
parent ad2d788016
commit 2aaa423842
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 99 additions and 66 deletions

View File

@ -11,8 +11,8 @@ import pytest
import torch
from tests.quantization.utils import is_quant_method_supported
from tests.v1.attention.utils import _Backend
from vllm import LLM, SamplingParams
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
PassConfig)

View File

@ -8,11 +8,11 @@ import torch._dynamo
from tests.compile.backend import LazyInitPass, TestBackend
from tests.models.utils import check_outputs_equal
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata)
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm import LLM, SamplingParams
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass

View File

@ -10,8 +10,9 @@ from unittest.mock import patch
import pytest
import torch
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.selector import _Backend, _cached_get_attn_backend
from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform

View File

@ -15,10 +15,10 @@ from torch._prims_common import TensorLikeType
from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.attention.backends.registry import _Backend
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.platforms.interface import _Backend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)

View File

@ -8,11 +8,11 @@ import pytest
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer

View File

@ -6,12 +6,12 @@ from typing import Optional, Union
import pytest
import torch
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from vllm import _custom_ops as ops
from vllm.attention.backends.registry import _Backend
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec

View File

@ -8,10 +8,11 @@ from typing import Optional, Union
import pytest
import torch
from vllm.attention.backends.registry import _Backend
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
LoadConfig, ModelConfig, ModelDType, ParallelConfig,
SchedulerConfig, VllmConfig)
from vllm.platforms import _Backend, current_platform
from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec

View File

@ -8,10 +8,10 @@ import pytest
import torch
from tests.utils import get_attn_backend_list_based_on_platform
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend)
from vllm.attention.backends.registry import _Backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)

View File

@ -6,10 +6,10 @@ from unittest import mock
import pytest
import torch
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend)
from vllm.attention.backends.registry import _Backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)

View File

@ -6,9 +6,10 @@ from typing import Optional
import torch
from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec,
from tests.v1.attention.utils import (create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from vllm.attention.backends.registry import _Backend
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata

View File

@ -0,0 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend registry"""
import enum
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
TRITON_ATTN = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
ROCM_AITER_MLA = enum.auto()
ROCM_AITER_FA = enum.auto() # used for ViT attn backend
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
FLASHINFER_MLA = enum.auto()
TRITON_MLA = enum.auto()
CUTLASS_MLA = enum.auto()
FLASHMLA = enum.auto()
FLASH_ATTN_MLA = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto()
TREE_ATTN = enum.auto()
ROCM_ATTN = enum.auto()

View File

@ -10,6 +10,7 @@ import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
@ -26,7 +27,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend, current_platform
from vllm.platforms import current_platform
from vllm.utils import GiB_bytes, direct_register_custom_op
logger = init_logger(__name__)

View File

@ -11,8 +11,9 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import _Backend
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
logger = init_logger(__name__)

View File

@ -20,6 +20,7 @@ import torch
import zmq
from vllm import envs
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
@ -32,7 +33,7 @@ from vllm.distributed.parallel_state import (
from vllm.distributed.utils import divide
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
from vllm.platforms import current_platform
from vllm.utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput

View File

@ -619,8 +619,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# All possible options loaded dynamically from _Backend enum
"VLLM_ATTENTION_BACKEND":
env_with_choices("VLLM_ATTENTION_BACKEND", None,
lambda: list(__import__('vllm.platforms.interface', \
fromlist=['_Backend'])._Backend.__members__.keys())),
lambda: list(__import__(
'vllm.attention.backends.registry',
fromlist=['_Backend'])._Backend.__members__.keys())),
# If set, vllm will use flashinfer sampler
"VLLM_USE_FLASHINFER_SAMPLER":

View File

@ -9,6 +9,7 @@ import torch.nn.functional as F
from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import utils as dist_utils
@ -38,7 +39,6 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
DotsVisionConfig)

View File

@ -34,6 +34,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import BatchFeature
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
@ -54,7 +55,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend, current_platform
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

View File

@ -46,6 +46,7 @@ from transformers.models.glm4v.video_processing_glm4v import (
Glm4vVideoProcessor)
from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
@ -69,7 +70,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

View File

@ -17,6 +17,7 @@ from transformers.modeling_outputs import (BaseModelOutput,
BaseModelOutputWithPooling)
from transformers.utils import torch_int
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
@ -39,7 +40,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape

View File

@ -38,6 +38,7 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
@ -62,7 +63,6 @@ from vllm.multimodal.evs import (compute_mrope_for_media,
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available
from vllm.utils.tensor_schema import TensorSchema, TensorShape

View File

@ -41,6 +41,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
Qwen2VLVideoProcessor)
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
@ -65,7 +66,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape

View File

@ -43,6 +43,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
smart_resize as video_smart_resize)
from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
@ -66,7 +67,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

View File

@ -13,6 +13,7 @@ from torch.nn import functional as F
from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
@ -22,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import _Backend
from .vision import get_vit_attn_backend

View File

@ -10,11 +10,12 @@ from typing import (Callable, Final, Generic, Literal, Optional, Protocol,
import torch
from transformers import PretrainedConfig
from vllm.attention.backends.registry import _Backend
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
from vllm.platforms import current_platform
logger = init_logger(__name__)

View File

@ -9,7 +9,6 @@ from vllm import envs
from vllm.plugins import load_plugins_by_group
from vllm.utils import resolve_obj_by_qualname, supports_xccl
from .interface import _Backend # noqa: F401
from .interface import CpuArchEnum, Platform, PlatformEnum
logger = logging.getLogger(__name__)

View File

@ -15,13 +15,15 @@ import torch
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend
from .interface import CpuArchEnum, Platform, PlatformEnum
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import VllmConfig
else:
_Backend = None
VllmConfig = None
@ -90,10 +92,11 @@ class CpuPlatform(Platform):
return "cpu"
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool, use_sparse: bool) -> str:
from vllm.attention.backends.registry import _Backend
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:

View File

@ -20,10 +20,13 @@ import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless, import_pynvml
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig
else:
_Backend = None
logger = init_logger(__name__)
@ -202,7 +205,8 @@ class CudaPlatformBase(Platform):
@classmethod
def get_vit_attn_backend(cls, head_size: int,
dtype: torch.dtype) -> _Backend:
dtype: torch.dtype) -> "_Backend":
from vllm.attention.backends.registry import _Backend
# For Blackwell GPUs, force TORCH_SDPA for now.
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
@ -230,6 +234,7 @@ class CudaPlatformBase(Platform):
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink, use_sparse) -> str:
from vllm.attention.backends.registry import _Backend
if use_mla:
if not use_v1:
raise RuntimeError(

View File

@ -17,12 +17,14 @@ from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import FlexibleArgumentParser
else:
_Backend = None
ModelConfig = None
VllmConfig = None
LoRARequest = None
@ -38,30 +40,6 @@ def in_wsl() -> bool:
return "microsoft" in " ".join(uname()).lower()
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
TRITON_ATTN = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
ROCM_AITER_MLA = enum.auto() # Supported by V1
ROCM_AITER_FA = enum.auto() # used for ViT attn backend
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
FLASHINFER_MLA = enum.auto()
TRITON_MLA = enum.auto() # Supported by V1
CUTLASS_MLA = enum.auto()
FLASHMLA = enum.auto() # Supported by V1
FLASH_ATTN_MLA = enum.auto() # Supported by V1
PALLAS = enum.auto()
IPEX = enum.auto()
DUAL_CHUNK_FLASH_ATTN = enum.auto()
DIFFERENTIAL_FLASH_ATTN = enum.auto()
NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto()
TREE_ATTN = enum.auto()
ROCM_ATTN = enum.auto()
class PlatformEnum(enum.Enum):
CUDA = enum.auto()
ROCM = enum.auto()
@ -187,11 +165,12 @@ class Platform:
@classmethod
def get_vit_attn_backend(cls, head_size: int,
dtype: torch.dtype) -> _Backend:
dtype: torch.dtype) -> "_Backend":
from vllm.attention.backends.registry import _Backend
return _Backend.TORCH_SDPA
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool, use_sparse: bool) -> str:

View File

@ -14,10 +14,13 @@ import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig
else:
_Backend = None
logger = init_logger(__name__)
@ -182,7 +185,8 @@ class RocmPlatform(Platform):
@classmethod
def get_vit_attn_backend(cls, head_size: int,
dtype: torch.dtype) -> _Backend:
dtype: torch.dtype) -> "_Backend":
from vllm.attention.backends.registry import _Backend
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
and on_gfx9()):
# Note: AITER FA is only supported for Qwen-VL models.
@ -196,6 +200,7 @@ class RocmPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink, use_sparse) -> str:
from vllm.attention.backends.registry import _Backend
if use_sparse:
raise NotImplementedError(
"Sparse Attention is not supported on ROCm.")

View File

@ -11,9 +11,10 @@ from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum, _Backend
from .interface import Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import BlockSize, ModelConfig, VllmConfig
from vllm.pooling_params import PoolingParams
else:
@ -21,6 +22,7 @@ else:
ModelConfig = None
VllmConfig = None
PoolingParams = None
_Backend = None
logger = init_logger(__name__)
@ -46,10 +48,11 @@ class TpuPlatform(Platform):
]
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool,
has_sink, use_sparse) -> str:
from vllm.attention.backends.registry import _Backend
if use_sparse:
raise NotImplementedError(
"Sparse Attention is not supported on TPU.")

View File

@ -10,13 +10,15 @@ import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig
else:
ModelConfig = None
VllmConfig = None
_Backend = None
logger = init_logger(__name__)
@ -33,10 +35,11 @@ class XPUPlatform(Platform):
device_control_env_var: str = "ZE_AFFINITY_MASK"
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool, use_sparse) -> str:
from vllm.attention.backends.registry import _Backend
if use_sparse:
raise NotImplementedError(
"Sparse Attention is not supported on XPU.")