[Attention] Move Backend enum into registry (#25893)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Matthew Bonanni 2025-10-02 23:32:24 -04:00 committed by yewentao256
parent 8db7b7f39c
commit 2ea7d48656
31 changed files with 99 additions and 66 deletions
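The change is mechanical but wide: the `_Backend` enum moves out of `vllm.platforms.interface` (and its re-export in `vllm.platforms`) into a new `vllm.attention.backends.registry` module, and every caller switches its import. A minimal sketch of the before/after, using only module paths that appear in the diff below:

# Before this commit (old import locations removed by this diff):
#   from vllm.platforms import _Backend
#   from vllm.platforms.interface import _Backend

# After this commit, the enum lives in the attention backend registry:
from vllm.attention.backends.registry import _Backend

# The members themselves are unchanged for existing call sites, e.g.:
assert _Backend.TORCH_SDPA is not _Backend.FLASH_ATTN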

View File

@@ -11,8 +11,8 @@ import pytest
 import torch
 from tests.quantization.utils import is_quant_method_supported
-from tests.v1.attention.utils import _Backend
 from vllm import LLM, SamplingParams
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.selector import global_force_attn_backend_context_manager
 from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
                          PassConfig)

View File

@@ -8,11 +8,11 @@ import torch._dynamo
 from tests.compile.backend import LazyInitPass, TestBackend
 from tests.models.utils import check_outputs_equal
-from tests.v1.attention.utils import (BatchSpec, _Backend,
-                                      create_common_attn_metadata)
+from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
 from vllm import LLM, SamplingParams
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.attention import Attention, AttentionMetadata
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.selector import global_force_attn_backend_context_manager
 from vllm.compilation.fusion import QUANT_OPS
 from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass

View File

@@ -10,8 +10,9 @@ from unittest.mock import patch
 import pytest
 import torch
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import MultiHeadAttention
-from vllm.attention.selector import _Backend, _cached_get_attn_backend
+from vllm.attention.selector import _cached_get_attn_backend
 from vllm.platforms import current_platform
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform

View File

@@ -15,10 +15,10 @@ from torch._prims_common import TensorLikeType
 
 from tests.kernels.quant_utils import native_w8a8_block_matmul
 from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
+from vllm.attention.backends.registry import _Backend
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input)
-from vllm.platforms.interface import _Backend
 from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
                         STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)

View File

@@ -8,11 +8,11 @@ import pytest
 import torch
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 
-from tests.v1.attention.utils import (BatchSpec, _Backend,
-                                      create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       create_vllm_config,
                                       get_attention_backend)
+from vllm.attention.backends.registry import _Backend
 from vllm.config import ModelConfig
 from vllm.platforms import current_platform
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer

View File

@@ -6,12 +6,12 @@ from typing import Optional, Union
 import pytest
 import torch
 
-from tests.v1.attention.utils import (BatchSpec, _Backend,
-                                      create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       create_vllm_config,
                                       get_attention_backend)
 from vllm import _custom_ops as ops
+from vllm.attention.backends.registry import _Backend
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec

View File

@@ -8,10 +8,11 @@ from typing import Optional, Union
 import pytest
 import torch
 
+from vllm.attention.backends.registry import _Backend
 from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
                          LoadConfig, ModelConfig, ModelDType, ParallelConfig,
                          SchedulerConfig, VllmConfig)
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
 from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec

View File

@@ -8,10 +8,10 @@ import pytest
 import torch
 
 from tests.utils import get_attn_backend_list_based_on_platform
-from tests.v1.attention.utils import (BatchSpec, _Backend,
-                                      create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       get_attention_backend)
+from vllm.attention.backends.registry import _Backend
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          VllmConfig)

View File

@@ -6,10 +6,10 @@ from unittest import mock
 import pytest
 import torch
 
-from tests.v1.attention.utils import (BatchSpec, _Backend,
-                                      create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       get_attention_backend)
+from vllm.attention.backends.registry import _Backend
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          VllmConfig)

View File

@@ -6,9 +6,10 @@ from typing import Optional
 import torch
 
-from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec,
+from tests.v1.attention.utils import (create_standard_kv_cache_spec,
                                       create_vllm_config,
                                       get_attention_backend)
+from vllm.attention.backends.registry import _Backend
 from vllm.config import ParallelConfig, SpeculativeConfig
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata

View File

@@ -0,0 +1,27 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Attention backend registry"""
+
+import enum
+
+
+class _Backend(enum.Enum):
+    FLASH_ATTN = enum.auto()
+    TRITON_ATTN = enum.auto()
+    XFORMERS = enum.auto()
+    ROCM_FLASH = enum.auto()
+    ROCM_AITER_MLA = enum.auto()
+    ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
+    TORCH_SDPA = enum.auto()
+    FLASHINFER = enum.auto()
+    FLASHINFER_MLA = enum.auto()
+    TRITON_MLA = enum.auto()
+    CUTLASS_MLA = enum.auto()
+    FLASHMLA = enum.auto()
+    FLASH_ATTN_MLA = enum.auto()
+    PALLAS = enum.auto()
+    IPEX = enum.auto()
+    NO_ATTENTION = enum.auto()
+    FLEX_ATTENTION = enum.auto()
+    TREE_ATTN = enum.auto()
+    ROCM_ATTN = enum.auto()
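
Because the registry is a plain `enum.Enum`, a backend selected by name (for instance the value of `VLLM_ATTENTION_BACKEND`) maps onto a member with an ordinary name lookup. A small illustrative sketch, not part of this diff; the helper name is hypothetical:

import os
from typing import Optional

from vllm.attention.backends.registry import _Backend


def resolve_backend(name: Optional[str]) -> Optional[_Backend]:
    """Hypothetical helper: map a backend name (e.g. from
    VLLM_ATTENTION_BACKEND) to its registry member."""
    if name is None:
        return None
    try:
        return _Backend[name]  # e.g. "FLASH_ATTN" -> _Backend.FLASH_ATTN
    except KeyError:
        raise ValueError(f"Unknown attention backend: {name!r}") from None


backend = resolve_backend(os.environ.get("VLLM_ATTENTION_BACKEND"))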

View File

@@ -10,6 +10,7 @@ import torch.nn.functional as F
 import vllm.envs as envs
 from vllm.attention import AttentionType
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
@@ -26,7 +27,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
 from vllm.model_executor.models.vision import get_vit_attn_backend
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
 from vllm.utils import GiB_bytes, direct_register_custom_op
 
 logger = init_logger(__name__)

View File

@@ -11,8 +11,9 @@ import torch
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.registry import _Backend
 from vllm.logger import init_logger
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
 from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
 
 logger = init_logger(__name__)

View File

@@ -20,6 +20,7 @@ import torch
 import zmq
 
 from vllm import envs
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
@@ -32,7 +33,7 @@ from vllm.distributed.parallel_state import (
 from vllm.distributed.utils import divide
 from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
 from vllm.utils import make_zmq_path, make_zmq_socket
 from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.core.sched.output import SchedulerOutput

View File

@@ -619,8 +619,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # All possible options loaded dynamically from _Backend enum
     "VLLM_ATTENTION_BACKEND":
     env_with_choices("VLLM_ATTENTION_BACKEND", None,
-                     lambda: list(__import__('vllm.platforms.interface', \
-                            fromlist=['_Backend'])._Backend.__members__.keys())),
+                     lambda: list(__import__(
+                         'vllm.attention.backends.registry',
+                         fromlist=['_Backend'])._Backend.__members__.keys())),
 
     # If set, vllm will use flashinfer sampler
     "VLLM_USE_FLASHINFER_SAMPLER":

View File

@@ -9,6 +9,7 @@ import torch.nn.functional as F
 from torch.nn import LayerNorm
 from transformers.models.qwen2_vl import Qwen2VLProcessor
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import utils as dist_utils
@@ -38,7 +39,6 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalDataDict
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
                                                      DotsVisionConfig)

View File

@@ -34,6 +34,7 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
 from transformers import BatchFeature
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
@@ -54,7 +55,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

View File

@@ -46,6 +46,7 @@ from transformers.models.glm4v.video_processing_glm4v import (
     Glm4vVideoProcessor)
 from transformers.video_utils import VideoMetadata
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import (get_tensor_model_parallel_world_size,
@@ -69,7 +70,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

View File

@@ -17,6 +17,7 @@ from transformers.modeling_outputs import (BaseModelOutput,
                                            BaseModelOutputWithPooling)
 from transformers.utils import torch_int
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -39,7 +40,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

View File

@@ -38,6 +38,7 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
@@ -62,7 +63,6 @@ from vllm.multimodal.evs import (compute_mrope_for_media,
 from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

View File

@@ -41,6 +41,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
 from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
     Qwen2VLVideoProcessor)
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
@@ -65,7 +66,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

View File

@@ -43,6 +43,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
     smart_resize as video_smart_resize)
 from transformers.video_utils import VideoMetadata
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
@@ -66,7 +67,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of

View File

@@ -13,6 +13,7 @@ from torch.nn import functional as F
 from transformers import Siglip2VisionConfig
 from transformers.configuration_utils import PretrainedConfig
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -22,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.platforms import _Backend
 
 from .vision import get_vit_attn_backend

View File

@@ -10,11 +10,12 @@ from typing import (Callable, Final, Generic, Literal, Optional, Protocol,
 import torch
 from transformers import PretrainedConfig
 
+from vllm.attention.backends.registry import _Backend
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather)
 from vllm.logger import init_logger
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)

View File

@@ -9,7 +9,6 @@ from vllm import envs
 from vllm.plugins import load_plugins_by_group
 from vllm.utils import resolve_obj_by_qualname, supports_xccl
 
-from .interface import _Backend  # noqa: F401
 from .interface import CpuArchEnum, Platform, PlatformEnum
 
 logger = logging.getLogger(__name__)

View File

@@ -15,13 +15,15 @@ import torch
 from vllm.logger import init_logger
 from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
 
-from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend
+from .interface import CpuArchEnum, Platform, PlatformEnum
 
 logger = init_logger(__name__)
 
 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import VllmConfig
 else:
+    _Backend = None
     VllmConfig = None
@@ -90,10 +92,11 @@ class CpuPlatform(Platform):
         return "cpu"
 
     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
                              has_sink: bool, use_sparse: bool) -> str:
+        from vllm.attention.backends.registry import _Backend
         if selected_backend and selected_backend != _Backend.TORCH_SDPA:
             logger.info("Cannot use %s backend on CPU.", selected_backend)
         if use_mla:

View File

@@ -20,10 +20,13 @@ import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils import cuda_device_count_stateless, import_pynvml
 
-from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
+from .interface import DeviceCapability, Platform, PlatformEnum
 
 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import ModelConfig, VllmConfig
+else:
+    _Backend = None
 
 logger = init_logger(__name__)
@@ -202,7 +205,8 @@ class CudaPlatformBase(Platform):
 
     @classmethod
     def get_vit_attn_backend(cls, head_size: int,
-                             dtype: torch.dtype) -> _Backend:
+                             dtype: torch.dtype) -> "_Backend":
+        from vllm.attention.backends.registry import _Backend
 
         # For Blackwell GPUs, force TORCH_SDPA for now.
         # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579  # noqa: E501
@@ -230,6 +234,7 @@ class CudaPlatformBase(Platform):
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1, use_mla,
                              has_sink, use_sparse) -> str:
+        from vllm.attention.backends.registry import _Backend
         if use_mla:
             if not use_v1:
                 raise RuntimeError(

View File

@@ -17,12 +17,14 @@ from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
 
 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import ModelConfig, VllmConfig
     from vllm.lora.request import LoRARequest
     from vllm.pooling_params import PoolingParams
     from vllm.sampling_params import SamplingParams
     from vllm.utils import FlexibleArgumentParser
 else:
+    _Backend = None
     ModelConfig = None
     VllmConfig = None
     LoRARequest = None
@@ -38,30 +40,6 @@ def in_wsl() -> bool:
     return "microsoft" in " ".join(uname()).lower()
 
 
-class _Backend(enum.Enum):
-    FLASH_ATTN = enum.auto()
-    TRITON_ATTN = enum.auto()
-    XFORMERS = enum.auto()
-    ROCM_FLASH = enum.auto()
-    ROCM_AITER_MLA = enum.auto()  # Supported by V1
-    ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
-    TORCH_SDPA = enum.auto()
-    FLASHINFER = enum.auto()
-    FLASHINFER_MLA = enum.auto()
-    TRITON_MLA = enum.auto()  # Supported by V1
-    CUTLASS_MLA = enum.auto()
-    FLASHMLA = enum.auto()  # Supported by V1
-    FLASH_ATTN_MLA = enum.auto()  # Supported by V1
-    PALLAS = enum.auto()
-    IPEX = enum.auto()
-    DUAL_CHUNK_FLASH_ATTN = enum.auto()
-    DIFFERENTIAL_FLASH_ATTN = enum.auto()
-    NO_ATTENTION = enum.auto()
-    FLEX_ATTENTION = enum.auto()
-    TREE_ATTN = enum.auto()
-    ROCM_ATTN = enum.auto()
-
-
 class PlatformEnum(enum.Enum):
     CUDA = enum.auto()
     ROCM = enum.auto()
@@ -187,11 +165,12 @@ class Platform:
 
     @classmethod
     def get_vit_attn_backend(cls, head_size: int,
-                             dtype: torch.dtype) -> _Backend:
+                             dtype: torch.dtype) -> "_Backend":
+        from vllm.attention.backends.registry import _Backend
         return _Backend.TORCH_SDPA
 
     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
                              has_sink: bool, use_sparse: bool) -> str:
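
The platform classes keep `_Backend` usable in annotations without importing the attention package at module load time, which would otherwise create a circular import between `vllm.platforms` and `vllm.attention`: the import happens under `TYPE_CHECKING`, annotations use the string form `"_Backend"`, and methods that need the enum at runtime import it locally. A condensed sketch of the pattern used in the hunks above (bodies trimmed, not the full class):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; no runtime dependency on vllm.attention.
    from vllm.attention.backends.registry import _Backend
else:
    _Backend = None  # keeps the name defined for anything that references it


class Platform:

    @classmethod
    def get_vit_attn_backend(cls, head_size, dtype) -> "_Backend":
        # Deferred import breaks the vllm.platforms <-> vllm.attention cycle.
        from vllm.attention.backends.registry import _Backend
        return _Backend.TORCH_SDPA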

View File

@@ -14,10 +14,13 @@ import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils import cuda_device_count_stateless
 
-from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
+from .interface import DeviceCapability, Platform, PlatformEnum
 
 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import ModelConfig, VllmConfig
+else:
+    _Backend = None
 
 logger = init_logger(__name__)
@@ -182,7 +185,8 @@ class RocmPlatform(Platform):
 
     @classmethod
     def get_vit_attn_backend(cls, head_size: int,
-                             dtype: torch.dtype) -> _Backend:
+                             dtype: torch.dtype) -> "_Backend":
+        from vllm.attention.backends.registry import _Backend
         if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
                 and on_gfx9()):
             # Note: AITER FA is only supported for Qwen-VL models.
@@ -196,6 +200,7 @@ class RocmPlatform(Platform):
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1, use_mla,
                              has_sink, use_sparse) -> str:
+        from vllm.attention.backends.registry import _Backend
         if use_sparse:
             raise NotImplementedError(
                 "Sparse Attention is not supported on ROCm.")

View File

@@ -11,9 +11,10 @@ from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
 
-from .interface import Platform, PlatformEnum, _Backend
+from .interface import Platform, PlatformEnum
 
 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import BlockSize, ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
 else:
@@ -21,6 +22,7 @@ else:
     ModelConfig = None
     VllmConfig = None
     PoolingParams = None
+    _Backend = None
 
 logger = init_logger(__name__)
@@ -46,10 +48,11 @@ class TpuPlatform(Platform):
     ]
 
     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
                              has_sink, use_sparse) -> str:
+        from vllm.attention.backends.registry import _Backend
         if use_sparse:
             raise NotImplementedError(
                 "Sparse Attention is not supported on TPU.")

View File

@@ -10,13 +10,15 @@ import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
 
-from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
+from .interface import DeviceCapability, Platform, PlatformEnum
 
 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import ModelConfig, VllmConfig
 else:
     ModelConfig = None
     VllmConfig = None
+    _Backend = None
 
 logger = init_logger(__name__)
@@ -33,10 +35,11 @@ class XPUPlatform(Platform):
     device_control_env_var: str = "ZE_AFFINITY_MASK"
 
     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
                              has_sink: bool, use_sparse) -> str:
+        from vllm.attention.backends.registry import _Backend
         if use_sparse:
             raise NotImplementedError(
                 "Sparse Attention is not supported on XPU.")