[Attention] Move Backend enum into registry (#25893)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent: ad2d788016
commit: 2aaa423842
@@ -11,8 +11,8 @@ import pytest
 import torch

 from tests.quantization.utils import is_quant_method_supported
-from tests.v1.attention.utils import _Backend
 from vllm import LLM, SamplingParams
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.selector import global_force_attn_backend_context_manager
 from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
                          PassConfig)
@@ -8,11 +8,11 @@ import torch._dynamo

 from tests.compile.backend import LazyInitPass, TestBackend
 from tests.models.utils import check_outputs_equal
-from tests.v1.attention.utils import (BatchSpec, _Backend,
-                                      create_common_attn_metadata)
+from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
 from vllm import LLM, SamplingParams
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.attention import Attention, AttentionMetadata
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.selector import global_force_attn_backend_context_manager
 from vllm.compilation.fusion import QUANT_OPS
 from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
@@ -10,8 +10,9 @@ from unittest.mock import patch
 import pytest
 import torch

+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import MultiHeadAttention
-from vllm.attention.selector import _Backend, _cached_get_attn_backend
+from vllm.attention.selector import _cached_get_attn_backend
 from vllm.platforms import current_platform
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
@@ -15,10 +15,10 @@ from torch._prims_common import TensorLikeType

 from tests.kernels.quant_utils import native_w8a8_block_matmul
 from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
+from vllm.attention.backends.registry import _Backend
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input)
-from vllm.platforms.interface import _Backend
 from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
                         STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
@@ -8,11 +8,11 @@ import pytest
 import torch
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention

-from tests.v1.attention.utils import (BatchSpec, _Backend,
-                                      create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       create_vllm_config,
                                       get_attention_backend)
+from vllm.attention.backends.registry import _Backend
 from vllm.config import ModelConfig
 from vllm.platforms import current_platform
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
@@ -6,12 +6,12 @@ from typing import Optional, Union
 import pytest
 import torch

-from tests.v1.attention.utils import (BatchSpec, _Backend,
-                                      create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       create_vllm_config,
                                       get_attention_backend)
 from vllm import _custom_ops as ops
+from vllm.attention.backends.registry import _Backend
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec
@@ -8,10 +8,11 @@ from typing import Optional, Union
 import pytest
 import torch

+from vllm.attention.backends.registry import _Backend
 from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
                          LoadConfig, ModelConfig, ModelDType, ParallelConfig,
                          SchedulerConfig, VllmConfig)
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
 from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec
@@ -8,10 +8,10 @@ import pytest
 import torch

 from tests.utils import get_attn_backend_list_based_on_platform
-from tests.v1.attention.utils import (BatchSpec, _Backend,
-                                      create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       get_attention_backend)
+from vllm.attention.backends.registry import _Backend
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          VllmConfig)
@@ -6,10 +6,10 @@ from unittest import mock
 import pytest
 import torch

-from tests.v1.attention.utils import (BatchSpec, _Backend,
-                                      create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       get_attention_backend)
+from vllm.attention.backends.registry import _Backend
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          VllmConfig)
@@ -6,9 +6,10 @@ from typing import Optional

 import torch

-from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec,
+from tests.v1.attention.utils import (create_standard_kv_cache_spec,
                                       create_vllm_config,
                                       get_attention_backend)
+from vllm.attention.backends.registry import _Backend
 from vllm.config import ParallelConfig, SpeculativeConfig
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
vllm/attention/backends/registry.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Attention backend registry"""
+
+import enum
+
+
+class _Backend(enum.Enum):
+    FLASH_ATTN = enum.auto()
+    TRITON_ATTN = enum.auto()
+    XFORMERS = enum.auto()
+    ROCM_FLASH = enum.auto()
+    ROCM_AITER_MLA = enum.auto()
+    ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
+    TORCH_SDPA = enum.auto()
+    FLASHINFER = enum.auto()
+    FLASHINFER_MLA = enum.auto()
+    TRITON_MLA = enum.auto()
+    CUTLASS_MLA = enum.auto()
+    FLASHMLA = enum.auto()
+    FLASH_ATTN_MLA = enum.auto()
+    PALLAS = enum.auto()
+    IPEX = enum.auto()
+    NO_ATTENTION = enum.auto()
+    FLEX_ATTENTION = enum.auto()
+    TREE_ATTN = enum.auto()
+    ROCM_ATTN = enum.auto()
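The registry module gives _Backend a home outside vllm.platforms.interface. A minimal illustrative sketch of the new import path (not part of the diff):

    from vllm.attention.backends.registry import _Backend

    # Members compare and enumerate the same way the old platform-level enum did.
    selected = _Backend.TORCH_SDPA
    assert selected is not _Backend.FLASH_ATTN
    print(sorted(_Backend.__members__))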
@@ -10,6 +10,7 @@ import torch.nn.functional as F
 import vllm.envs as envs
 from vllm.attention import AttentionType
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
@@ -26,7 +27,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
 from vllm.model_executor.models.vision import get_vit_attn_backend
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
 from vllm.utils import GiB_bytes, direct_register_custom_op

 logger = init_logger(__name__)
@@ -11,8 +11,9 @@ import torch

 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.registry import _Backend
 from vllm.logger import init_logger
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
 from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname

 logger = init_logger(__name__)
@@ -20,6 +20,7 @@ import torch
 import zmq

 from vllm import envs
+from vllm.attention.backends.registry import _Backend
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
@@ -32,7 +33,7 @@ from vllm.distributed.parallel_state import (
 from vllm.distributed.utils import divide
 from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
 from vllm.utils import make_zmq_path, make_zmq_socket
 from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -619,8 +619,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # All possible options loaded dynamically from _Backend enum
     "VLLM_ATTENTION_BACKEND":
     env_with_choices("VLLM_ATTENTION_BACKEND", None,
-                     lambda: list(__import__('vllm.platforms.interface', \
-                        fromlist=['_Backend'])._Backend.__members__.keys())),
+                     lambda: list(__import__(
+                         'vllm.attention.backends.registry',
+                         fromlist=['_Backend'])._Backend.__members__.keys())),

     # If set, vllm will use flashinfer sampler
     "VLLM_USE_FLASHINFER_SAMPLER":
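The lambda above resolves the allowed values lazily, so importing vllm.envs does not import the attention registry. A rough equivalent, sketched under the assumption that env_with_choices accepts any zero-argument callable returning the valid option names:

    import importlib

    def _backend_choices() -> list[str]:
        # Resolve the enum at call time, mirroring the __import__ in the lambda.
        registry = importlib.import_module("vllm.attention.backends.registry")
        return list(registry._Backend.__members__.keys())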
@@ -9,6 +9,7 @@ import torch.nn.functional as F
 from torch.nn import LayerNorm
 from transformers.models.qwen2_vl import Qwen2VLProcessor

+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import utils as dist_utils
@@ -38,7 +39,6 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalDataDict
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
                                                      DotsVisionConfig)
@@ -34,6 +34,7 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
 from transformers import BatchFeature

+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
@@ -54,7 +55,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -46,6 +46,7 @@ from transformers.models.glm4v.video_processing_glm4v import (
     Glm4vVideoProcessor)
 from transformers.video_utils import VideoMetadata

+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import (get_tensor_model_parallel_world_size,
@@ -69,7 +70,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -17,6 +17,7 @@ from transformers.modeling_outputs import (BaseModelOutput,
                                            BaseModelOutputWithPooling)
 from transformers.utils import torch_int

+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -39,7 +40,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -38,6 +38,7 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)

+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
@@ -62,7 +63,6 @@ from vllm.multimodal.evs import (compute_mrope_for_media,
 from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -41,6 +41,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
 from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
     Qwen2VLVideoProcessor)

+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
@@ -65,7 +66,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -43,6 +43,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
     smart_resize as video_smart_resize)
 from transformers.video_utils import VideoMetadata

+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
@@ -66,7 +67,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
@@ -13,6 +13,7 @@ from torch.nn import functional as F
 from transformers import Siglip2VisionConfig
 from transformers.configuration_utils import PretrainedConfig

+from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -22,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.platforms import _Backend

 from .vision import get_vit_attn_backend
@@ -10,11 +10,12 @@ from typing import (Callable, Final, Generic, Literal, Optional, Protocol,
 import torch
 from transformers import PretrainedConfig

+from vllm.attention.backends.registry import _Backend
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather)
 from vllm.logger import init_logger
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform

 logger = init_logger(__name__)
@@ -9,7 +9,6 @@ from vllm import envs
 from vllm.plugins import load_plugins_by_group
 from vllm.utils import resolve_obj_by_qualname, supports_xccl

-from .interface import _Backend  # noqa: F401
 from .interface import CpuArchEnum, Platform, PlatformEnum

 logger = logging.getLogger(__name__)
@@ -15,13 +15,15 @@ import torch
 from vllm.logger import init_logger
 from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

-from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend
+from .interface import CpuArchEnum, Platform, PlatformEnum

 logger = init_logger(__name__)

 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import VllmConfig
 else:
+    _Backend = None
     VllmConfig = None
@@ -90,10 +92,11 @@ class CpuPlatform(Platform):
         return "cpu"

     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
                              has_sink: bool, use_sparse: bool) -> str:
+        from vllm.attention.backends.registry import _Backend
         if selected_backend and selected_backend != _Backend.TORCH_SDPA:
             logger.info("Cannot use %s backend on CPU.", selected_backend)
         if use_mla:
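The remaining platform modules repeat the pattern shown here for CpuPlatform: _Backend exists as a name only under TYPE_CHECKING (hence the quoted annotations) and is imported inside the method body at call time, presumably to keep vllm.platforms import-light and to avoid a circular import with the attention package. A condensed sketch using a hypothetical DemoPlatform class (not from the diff):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from vllm.attention.backends.registry import _Backend
    else:
        _Backend = None


    class DemoPlatform:

        @classmethod
        def get_attn_backend_cls(cls, selected_backend: "_Backend") -> str:
            # Deferred import: runs only when a backend class is requested.
            from vllm.attention.backends.registry import _Backend
            if selected_backend == _Backend.TORCH_SDPA:
                return "torch_sdpa_backend_path"  # hypothetical return value
            return "default_backend_path"  # hypothetical return value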
@@ -20,10 +20,13 @@ import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils import cuda_device_count_stateless, import_pynvml

-from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
+from .interface import DeviceCapability, Platform, PlatformEnum

 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import ModelConfig, VllmConfig
+else:
+    _Backend = None

 logger = init_logger(__name__)
@@ -202,7 +205,8 @@ class CudaPlatformBase(Platform):

     @classmethod
     def get_vit_attn_backend(cls, head_size: int,
-                             dtype: torch.dtype) -> _Backend:
+                             dtype: torch.dtype) -> "_Backend":
+        from vllm.attention.backends.registry import _Backend

         # For Blackwell GPUs, force TORCH_SDPA for now.
         # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
@@ -230,6 +234,7 @@ class CudaPlatformBase(Platform):
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1, use_mla,
                              has_sink, use_sparse) -> str:
+        from vllm.attention.backends.registry import _Backend
         if use_mla:
             if not use_v1:
                 raise RuntimeError(
@@ -17,12 +17,14 @@ from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger

 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import ModelConfig, VllmConfig
     from vllm.lora.request import LoRARequest
     from vllm.pooling_params import PoolingParams
     from vllm.sampling_params import SamplingParams
     from vllm.utils import FlexibleArgumentParser
 else:
+    _Backend = None
     ModelConfig = None
     VllmConfig = None
     LoRARequest = None
@@ -38,30 +40,6 @@ def in_wsl() -> bool:
     return "microsoft" in " ".join(uname()).lower()


-class _Backend(enum.Enum):
-    FLASH_ATTN = enum.auto()
-    TRITON_ATTN = enum.auto()
-    XFORMERS = enum.auto()
-    ROCM_FLASH = enum.auto()
-    ROCM_AITER_MLA = enum.auto()  # Supported by V1
-    ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
-    TORCH_SDPA = enum.auto()
-    FLASHINFER = enum.auto()
-    FLASHINFER_MLA = enum.auto()
-    TRITON_MLA = enum.auto()  # Supported by V1
-    CUTLASS_MLA = enum.auto()
-    FLASHMLA = enum.auto()  # Supported by V1
-    FLASH_ATTN_MLA = enum.auto()  # Supported by V1
-    PALLAS = enum.auto()
-    IPEX = enum.auto()
-    DUAL_CHUNK_FLASH_ATTN = enum.auto()
-    DIFFERENTIAL_FLASH_ATTN = enum.auto()
-    NO_ATTENTION = enum.auto()
-    FLEX_ATTENTION = enum.auto()
-    TREE_ATTN = enum.auto()
-    ROCM_ATTN = enum.auto()
-
-
 class PlatformEnum(enum.Enum):
     CUDA = enum.auto()
     ROCM = enum.auto()
@@ -187,11 +165,12 @@ class Platform:

     @classmethod
     def get_vit_attn_backend(cls, head_size: int,
-                             dtype: torch.dtype) -> _Backend:
+                             dtype: torch.dtype) -> "_Backend":
+        from vllm.attention.backends.registry import _Backend
         return _Backend.TORCH_SDPA

     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
                              has_sink: bool, use_sparse: bool) -> str:
@@ -14,10 +14,13 @@ import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils import cuda_device_count_stateless

-from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
+from .interface import DeviceCapability, Platform, PlatformEnum

 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import ModelConfig, VllmConfig
+else:
+    _Backend = None

 logger = init_logger(__name__)
@@ -182,7 +185,8 @@ class RocmPlatform(Platform):

     @classmethod
     def get_vit_attn_backend(cls, head_size: int,
-                             dtype: torch.dtype) -> _Backend:
+                             dtype: torch.dtype) -> "_Backend":
+        from vllm.attention.backends.registry import _Backend
         if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
                 and on_gfx9()):
             # Note: AITER FA is only supported for Qwen-VL models.
@@ -196,6 +200,7 @@ class RocmPlatform(Platform):
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1, use_mla,
                              has_sink, use_sparse) -> str:
+        from vllm.attention.backends.registry import _Backend
         if use_sparse:
             raise NotImplementedError(
                 "Sparse Attention is not supported on ROCm.")
@@ -11,9 +11,10 @@ from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

-from .interface import Platform, PlatformEnum, _Backend
+from .interface import Platform, PlatformEnum

 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import BlockSize, ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
 else:
@@ -21,6 +22,7 @@ else:
     ModelConfig = None
     VllmConfig = None
     PoolingParams = None
+    _Backend = None

 logger = init_logger(__name__)
@@ -46,10 +48,11 @@ class TpuPlatform(Platform):
     ]

     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
                              has_sink, use_sparse) -> str:
+        from vllm.attention.backends.registry import _Backend
         if use_sparse:
             raise NotImplementedError(
                 "Sparse Attention is not supported on TPU.")
@@ -10,13 +10,15 @@ import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

-from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
+from .interface import DeviceCapability, Platform, PlatformEnum

 if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import ModelConfig, VllmConfig
 else:
     ModelConfig = None
     VllmConfig = None
+    _Backend = None

 logger = init_logger(__name__)
@@ -33,10 +35,11 @@ class XPUPlatform(Platform):
     device_control_env_var: str = "ZE_AFFINITY_MASK"

     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
                              has_sink: bool, use_sparse) -> str:
+        from vllm.attention.backends.registry import _Backend
         if use_sparse:
             raise NotImplementedError(
                 "Sparse Attention is not supported on XPU.")