diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py
index f9f146810924..3ecda1a8ec33 100644
--- a/tests/compile/test_full_graph.py
+++ b/tests/compile/test_full_graph.py
@@ -11,8 +11,8 @@ import pytest
 import torch
 
 from tests.quantization.utils import is_quant_method_supported
-from tests.v1.attention.utils import _Backend
 from vllm import LLM, SamplingParams
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.selector import global_force_attn_backend_context_manager
 from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
                          PassConfig)
diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py
index eb8c49135428..077cf11d048a 100644
--- a/tests/compile/test_fusion_attn.py
+++ b/tests/compile/test_fusion_attn.py
@@ -8,11 +8,11 @@ import torch._dynamo
 
 from tests.compile.backend import LazyInitPass, TestBackend
 from tests.models.utils import check_outputs_equal
-from tests.v1.attention.utils import (BatchSpec, _Backend,
-                                      create_common_attn_metadata)
+from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
 from vllm import LLM, SamplingParams
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.attention import Attention, AttentionMetadata
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.selector import global_force_attn_backend_context_manager
 from vllm.compilation.fusion import QUANT_OPS
 from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py
index d37b968ed979..cea08e19f52d 100644
--- a/tests/kernels/attention/test_mha_attn.py
+++ b/tests/kernels/attention/test_mha_attn.py
@@ -10,8 +10,9 @@ from unittest.mock import patch
 import pytest
 import torch
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import MultiHeadAttention
-from vllm.attention.selector import _Backend, _cached_get_attn_backend
+from vllm.attention.selector import _cached_get_attn_backend
 from vllm.platforms import current_platform
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py
index 0fdaa600aefa..db6f29c28c95 100644
--- a/tests/kernels/utils.py
+++ b/tests/kernels/utils.py
@@ -15,10 +15,10 @@ from torch._prims_common import TensorLikeType
 
 from tests.kernels.quant_utils import native_w8a8_block_matmul
 from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
+from vllm.attention.backends.registry import _Backend
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input)
-from vllm.platforms.interface import _Backend
 from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
                         STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py
index 6c17be759ab6..24cdd8afbb3b 100644
--- a/tests/v1/attention/test_attention_backends.py
+++ b/tests/v1/attention/test_attention_backends.py
@@ -8,11 +8,11 @@ import pytest
 import torch
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 
-from tests.v1.attention.utils import (BatchSpec, _Backend,
-                                      create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       create_vllm_config,
                                       get_attention_backend)
+from vllm.attention.backends.registry import _Backend
 from vllm.config import ModelConfig
 from vllm.platforms import current_platform
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py
index 228551573ba8..f2d0a5b2407a 100644
--- a/tests/v1/attention/test_mla_backends.py
+++ b/tests/v1/attention/test_mla_backends.py
@@ -6,12 +6,12 @@ from typing import Optional, Union
 import pytest
 import torch
 
-from tests.v1.attention.utils import (BatchSpec, _Backend,
-                                      create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       create_vllm_config,
                                       get_attention_backend)
 from vllm import _custom_ops as ops
+from vllm.attention.backends.registry import _Backend
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec
diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py
index adfe2b2db040..2bea45210ff3 100644
--- a/tests/v1/attention/utils.py
+++ b/tests/v1/attention/utils.py
@@ -8,10 +8,11 @@ from typing import Optional, Union
 import pytest
 import torch
 
+from vllm.attention.backends.registry import _Backend
 from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
                          LoadConfig, ModelConfig, ModelDType, ParallelConfig,
                          SchedulerConfig, VllmConfig)
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
 from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec
diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index 49311c0005e7..938c6543e9b0 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -8,10 +8,10 @@ import pytest
 import torch
 
 from tests.utils import get_attn_backend_list_based_on_platform
-from tests.v1.attention.utils import (BatchSpec, _Backend,
-                                      create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       get_attention_backend)
+from vllm.attention.backends.registry import _Backend
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          VllmConfig)
diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py
index 5b9ccfc3f48b..dc4a56c66de6 100644
--- a/tests/v1/spec_decode/test_mtp.py
+++ b/tests/v1/spec_decode/test_mtp.py
@@ -6,10 +6,10 @@ from unittest import mock
 import pytest
 import torch
 
-from tests.v1.attention.utils import (BatchSpec, _Backend,
-                                      create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       get_attention_backend)
+from vllm.attention.backends.registry import _Backend
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          VllmConfig)
diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py
index 51a737496dff..ebb9a3d97861 100644
--- a/tests/v1/spec_decode/test_tree_attention.py
+++ b/tests/v1/spec_decode/test_tree_attention.py
@@ -6,9 +6,10 @@ from typing import Optional
 
 import torch
 
-from tests.v1.attention.utils import (_Backend,
-                                      create_standard_kv_cache_spec,
+from tests.v1.attention.utils import (create_standard_kv_cache_spec,
                                       create_vllm_config,
                                       get_attention_backend)
+from vllm.attention.backends.registry import _Backend
 from vllm.config import ParallelConfig, SpeculativeConfig
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py
new file mode 100644
index 000000000000..6377e8619b3c
--- /dev/null
+++ b/vllm/attention/backends/registry.py
@@ -0,0 +1,27 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Attention backend registry"""
+
+import enum
+
+
+class _Backend(enum.Enum):
+    FLASH_ATTN = enum.auto()
+    TRITON_ATTN = enum.auto()
+    XFORMERS = enum.auto()
+    ROCM_FLASH = enum.auto()
+    ROCM_AITER_MLA = enum.auto()
+    ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
+    TORCH_SDPA = enum.auto()
+    FLASHINFER = enum.auto()
+    FLASHINFER_MLA = enum.auto()
+    TRITON_MLA = enum.auto()
+    CUTLASS_MLA = enum.auto()
+    FLASHMLA = enum.auto()
+    FLASH_ATTN_MLA = enum.auto()
+    PALLAS = enum.auto()
+    IPEX = enum.auto()
+    NO_ATTENTION = enum.auto()
+    FLEX_ATTENTION = enum.auto()
+    TREE_ATTN = enum.auto()
+    ROCM_ATTN = enum.auto()
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 4ce6a864d7ad..113602645e89 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -10,6 +10,7 @@ import torch.nn.functional as F
 import vllm.envs as envs
 from vllm.attention import AttentionType
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
@@ -26,7 +27,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
 from vllm.model_executor.models.vision import get_vit_attn_backend
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
 from vllm.utils import GiB_bytes, direct_register_custom_op
 
 logger = init_logger(__name__)
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 6f048e589f7f..d3214fecfa70 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -11,8 +11,9 @@ import torch
 
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.registry import _Backend
 from vllm.logger import init_logger
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
 from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
 
 logger = init_logger(__name__)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 55d87ea994b5..4706c5130899 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -20,6 +20,7 @@ import torch
 import zmq
 
 from vllm import envs
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
@@ -32,7 +33,7 @@ from vllm.distributed.parallel_state import (
 from vllm.distributed.utils import divide
 from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
 from vllm.utils import make_zmq_path, make_zmq_socket
 from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.core.sched.output import SchedulerOutput
diff --git a/vllm/envs.py b/vllm/envs.py
index 3d7d3c576dab..6dce4bd0f94e 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -619,8 +619,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # All possible options loaded dynamically from _Backend enum
     "VLLM_ATTENTION_BACKEND":
     env_with_choices("VLLM_ATTENTION_BACKEND", None,
-                     lambda: list(__import__('vllm.platforms.interface', \
-                        fromlist=['_Backend'])._Backend.__members__.keys())),
+                     lambda: list(__import__(
+                         'vllm.attention.backends.registry',
+                         fromlist=['_Backend'])._Backend.__members__.keys())),
 
     # If set, vllm will use flashinfer sampler
     "VLLM_USE_FLASHINFER_SAMPLER":
diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py
index e68777aab6bf..2445f0d784f4 100644
--- a/vllm/model_executor/models/dots_ocr.py
+++ b/vllm/model_executor/models/dots_ocr.py
@@ -9,6 +9,7 @@ import torch.nn.functional as F
 from torch.nn import LayerNorm
 from transformers.models.qwen2_vl import Qwen2VLProcessor
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import utils as dist_utils
@@ -38,7 +39,6 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalDataDict
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
                                                      DotsVisionConfig)
diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index c62658fa4c21..0b8e24407602 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -34,6 +34,7 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
 from transformers import BatchFeature
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
@@ -54,7 +55,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 722f1e428be7..315a057e6a7d 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -46,6 +46,7 @@ from transformers.models.glm4v.video_processing_glm4v import (
     Glm4vVideoProcessor)
 from transformers.video_utils import VideoMetadata
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import (get_tensor_model_parallel_world_size,
@@ -69,7 +70,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py
index 10b5c45169f4..90de0582b94a 100644
--- a/vllm/model_executor/models/keye.py
+++ b/vllm/model_executor/models/keye.py
@@ -17,6 +17,7 @@ from transformers.modeling_outputs import (BaseModelOutput,
                                            BaseModelOutputWithPooling)
 from transformers.utils import torch_int
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -39,7 +40,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index da3889d31a7d..a70df3b72be4 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -38,6 +38,7 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
@@ -62,7 +63,6 @@ from vllm.multimodal.evs import (compute_mrope_for_media,
 from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 38435a69444e..2ff79765d4be 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -41,6 +41,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
 from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
     Qwen2VLVideoProcessor)
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
@@ -65,7 +66,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 00de89811cc7..fc8557131c3e 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -43,6 +43,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
     smart_resize as video_smart_resize)
 from transformers.video_utils import VideoMetadata
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
@@ -66,7 +67,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py
index 18de4b576c49..d111a10809e7 100644
--- a/vllm/model_executor/models/siglip2navit.py
+++ b/vllm/model_executor/models/siglip2navit.py
@@ -13,6 +13,7 @@ from torch.nn import functional as F
 from transformers import Siglip2VisionConfig
 from transformers.configuration_utils import PretrainedConfig
 
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -22,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.platforms import _Backend
 
 from .vision import get_vit_attn_backend
diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py
index 3d16d71e1764..2636942580fa 100644
--- a/vllm/model_executor/models/vision.py
+++ b/vllm/model_executor/models/vision.py
@@ -10,11 +10,12 @@ from typing import (Callable, Final, Generic, Literal, Optional, Protocol,
 import torch
 from transformers import PretrainedConfig
 
+from vllm.attention.backends.registry import _Backend
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather)
 from vllm.logger import init_logger
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py
index 9b64817da648..7549de480ee6 100644
--- a/vllm/platforms/__init__.py
+++ b/vllm/platforms/__init__.py
@@ -9,7 +9,6 @@ from vllm import envs
 from vllm.plugins import load_plugins_by_group
 from vllm.utils import resolve_obj_by_qualname, supports_xccl
 
-from .interface import _Backend  # noqa: F401
 from .interface import CpuArchEnum, Platform, PlatformEnum
 
 logger = logging.getLogger(__name__)
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 0b26446a87d8..436e295e58e6 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -15,13 +15,15 @@ import torch
 from vllm.logger import init_logger
 from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
 
-from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend
+from .interface import CpuArchEnum, Platform, PlatformEnum
 
 logger = init_logger(__name__)
 
 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import VllmConfig
 else:
+    _Backend = None
     VllmConfig = None
@@ -90,10 +92,11 @@ class CpuPlatform(Platform):
         return "cpu"
 
     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
                              has_sink: bool, use_sparse: bool) -> str:
+        from vllm.attention.backends.registry import _Backend
         if selected_backend and selected_backend != _Backend.TORCH_SDPA:
             logger.info("Cannot use %s backend on CPU.", selected_backend)
         if use_mla:
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index a9a8d9ea2625..b7baa614957e 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -20,10 +20,13 @@ import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils import cuda_device_count_stateless, import_pynvml
 
-from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
+from .interface import DeviceCapability, Platform, PlatformEnum
 
 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import ModelConfig, VllmConfig
+else:
+    _Backend = None
 
 logger = init_logger(__name__)
@@ -202,7 +205,8 @@ class CudaPlatformBase(Platform):
 
     @classmethod
     def get_vit_attn_backend(cls, head_size: int,
-                             dtype: torch.dtype) -> _Backend:
+                             dtype: torch.dtype) -> "_Backend":
+        from vllm.attention.backends.registry import _Backend
         # For Blackwell GPUs, force TORCH_SDPA for now.
         # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
@@ -230,6 +234,7 @@ class CudaPlatformBase(Platform):
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1, use_mla,
                              has_sink, use_sparse) -> str:
+        from vllm.attention.backends.registry import _Backend
         if use_mla:
             if not use_v1:
                 raise RuntimeError(
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 1691ad62650b..df1395fa842a 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -17,12 +17,14 @@ from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
 
 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import ModelConfig, VllmConfig
     from vllm.lora.request import LoRARequest
     from vllm.pooling_params import PoolingParams
     from vllm.sampling_params import SamplingParams
     from vllm.utils import FlexibleArgumentParser
 else:
+    _Backend = None
     ModelConfig = None
     VllmConfig = None
     LoRARequest = None
@@ -38,30 +40,6 @@ def in_wsl() -> bool:
     return "microsoft" in " ".join(uname()).lower()
 
 
-class _Backend(enum.Enum):
-    FLASH_ATTN = enum.auto()
-    TRITON_ATTN = enum.auto()
-    XFORMERS = enum.auto()
-    ROCM_FLASH = enum.auto()
-    ROCM_AITER_MLA = enum.auto()  # Supported by V1
-    ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
-    TORCH_SDPA = enum.auto()
-    FLASHINFER = enum.auto()
-    FLASHINFER_MLA = enum.auto()
-    TRITON_MLA = enum.auto()  # Supported by V1
-    CUTLASS_MLA = enum.auto()
-    FLASHMLA = enum.auto()  # Supported by V1
-    FLASH_ATTN_MLA = enum.auto()  # Supported by V1
-    PALLAS = enum.auto()
-    IPEX = enum.auto()
-    DUAL_CHUNK_FLASH_ATTN = enum.auto()
-    DIFFERENTIAL_FLASH_ATTN = enum.auto()
-    NO_ATTENTION = enum.auto()
-    FLEX_ATTENTION = enum.auto()
-    TREE_ATTN = enum.auto()
-    ROCM_ATTN = enum.auto()
-
-
 class PlatformEnum(enum.Enum):
     CUDA = enum.auto()
     ROCM = enum.auto()
@@ -187,11 +165,12 @@ class Platform:
 
     @classmethod
     def get_vit_attn_backend(cls, head_size: int,
-                             dtype: torch.dtype) -> _Backend:
+                             dtype: torch.dtype) -> "_Backend":
+        from vllm.attention.backends.registry import _Backend
         return _Backend.TORCH_SDPA
 
     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
                              has_sink: bool, use_sparse: bool) -> str:
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 14762f1b7094..e12967ad2587 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -14,10 +14,13 @@ import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils import cuda_device_count_stateless
 
-from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
+from .interface import DeviceCapability, Platform, PlatformEnum
 
 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import ModelConfig, VllmConfig
+else:
+    _Backend = None
 
 logger = init_logger(__name__)
@@ -182,7 +185,8 @@ class RocmPlatform(Platform):
 
     @classmethod
     def get_vit_attn_backend(cls, head_size: int,
-                             dtype: torch.dtype) -> _Backend:
+                             dtype: torch.dtype) -> "_Backend":
+        from vllm.attention.backends.registry import _Backend
         if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
                 and on_gfx9()):
             # Note: AITER FA is only supported for Qwen-VL models.
@@ -196,6 +200,7 @@ class RocmPlatform(Platform):
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1, use_mla,
                              has_sink, use_sparse) -> str:
+        from vllm.attention.backends.registry import _Backend
         if use_sparse:
             raise NotImplementedError(
                 "Sparse Attention is not supported on ROCm.")
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 4a4931f7f009..91a01a4f4ee9 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -11,9 +11,10 @@ from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
 
-from .interface import Platform, PlatformEnum, _Backend
+from .interface import Platform, PlatformEnum
 
 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import BlockSize, ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
 else:
@@ -21,6 +22,7 @@ else:
     ModelConfig = None
     VllmConfig = None
     PoolingParams = None
+    _Backend = None
 
 logger = init_logger(__name__)
@@ -46,10 +48,11 @@ class TpuPlatform(Platform):
     ]
 
    @classmethod
-    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
                              has_sink, use_sparse) -> str:
+        from vllm.attention.backends.registry import _Backend
         if use_sparse:
             raise NotImplementedError(
                 "Sparse Attention is not supported on TPU.")
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 12d6a2a2d1ba..3ccbae58726f 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -10,13 +10,15 @@ import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
 
-from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
+from .interface import DeviceCapability, Platform, PlatformEnum
 
 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import ModelConfig, VllmConfig
 else:
     ModelConfig = None
     VllmConfig = None
+    _Backend = None
 
 logger = init_logger(__name__)
@@ -33,10 +35,11 @@ class XPUPlatform(Platform):
     device_control_env_var: str = "ZE_AFFINITY_MASK"
 
     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
                              has_sink: bool, use_sparse) -> str:
+        from vllm.attention.backends.registry import _Backend
         if use_sparse:
             raise NotImplementedError(
                 "Sparse Attention is not supported on XPU.")
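For callers, the net effect of this patch is a single import change: `_Backend` now lives in `vllm.attention.backends.registry` rather than `vllm.platforms.interface`, and the platform hooks import it lazily inside the method body. The sketch below is hypothetical usage, not part of the patch; only the module path and the enum member names come from the registry file added above.

```python
# Hypothetical usage sketch; only the module path and enum member names are
# taken from the registry module introduced by this patch.
from vllm.attention.backends.registry import _Backend


def resolve_backend(name: str) -> _Backend:
    # Look the backend up by name, e.g. "FLASH_ATTN" or "TRITON_ATTN".
    return _Backend[name]


assert resolve_backend("FLASH_ATTN") is _Backend.FLASH_ATTN
```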