mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-01 23:10:13 +08:00
[Attention] Move Backend enum into registry (#25893)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
8db7b7f39c
commit
2ea7d48656
@ -11,8 +11,8 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.quantization.utils import is_quant_method_supported
|
from tests.quantization.utils import is_quant_method_supported
|
||||||
from tests.v1.attention.utils import _Backend
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||||
PassConfig)
|
PassConfig)
|
||||||
|
|||||||
@ -8,11 +8,11 @@ import torch._dynamo
|
|||||||
|
|
||||||
from tests.compile.backend import LazyInitPass, TestBackend
|
from tests.compile.backend import LazyInitPass, TestBackend
|
||||||
from tests.models.utils import check_outputs_equal
|
from tests.models.utils import check_outputs_equal
|
||||||
from tests.v1.attention.utils import (BatchSpec, _Backend,
|
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||||
create_common_attn_metadata)
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||||
from vllm.compilation.fusion import QUANT_OPS
|
from vllm.compilation.fusion import QUANT_OPS
|
||||||
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
||||||
|
|||||||
@ -10,8 +10,9 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import MultiHeadAttention
|
from vllm.attention.layer import MultiHeadAttention
|
||||||
from vllm.attention.selector import _Backend, _cached_get_attn_backend
|
from vllm.attention.selector import _cached_get_attn_backend
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.cpu import CpuPlatform
|
from vllm.platforms.cpu import CpuPlatform
|
||||||
from vllm.platforms.cuda import CudaPlatform
|
from vllm.platforms.cuda import CudaPlatform
|
||||||
|
|||||||
@ -15,10 +15,10 @@ from torch._prims_common import TensorLikeType
|
|||||||
|
|
||||||
from tests.kernels.quant_utils import native_w8a8_block_matmul
|
from tests.kernels.quant_utils import native_w8a8_block_matmul
|
||||||
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
|
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe.utils import (
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
moe_kernel_quantize_input)
|
moe_kernel_quantize_input)
|
||||||
from vllm.platforms.interface import _Backend
|
|
||||||
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
|
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
|
||||||
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
|
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
|
||||||
|
|
||||||
|
|||||||
@ -8,11 +8,11 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
|
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
|
||||||
|
|
||||||
from tests.v1.attention.utils import (BatchSpec, _Backend,
|
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
|
||||||
create_common_attn_metadata,
|
|
||||||
create_standard_kv_cache_spec,
|
create_standard_kv_cache_spec,
|
||||||
create_vllm_config,
|
create_vllm_config,
|
||||||
get_attention_backend)
|
get_attention_backend)
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
|
||||||
|
|||||||
@ -6,12 +6,12 @@ from typing import Optional, Union
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.v1.attention.utils import (BatchSpec, _Backend,
|
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
|
||||||
create_common_attn_metadata,
|
|
||||||
create_standard_kv_cache_spec,
|
create_standard_kv_cache_spec,
|
||||||
create_vllm_config,
|
create_vllm_config,
|
||||||
get_attention_backend)
|
get_attention_backend)
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||||
|
|||||||
@ -8,10 +8,11 @@ from typing import Optional, Union
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
|
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
|
||||||
LoadConfig, ModelConfig, ModelDType, ParallelConfig,
|
LoadConfig, ModelConfig, ModelDType, ParallelConfig,
|
||||||
SchedulerConfig, VllmConfig)
|
SchedulerConfig, VllmConfig)
|
||||||
from vllm.platforms import _Backend, current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import resolve_obj_by_qualname
|
from vllm.utils import resolve_obj_by_qualname
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||||
|
|||||||
@ -8,10 +8,10 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.utils import get_attn_backend_list_based_on_platform
|
from tests.utils import get_attn_backend_list_based_on_platform
|
||||||
from tests.v1.attention.utils import (BatchSpec, _Backend,
|
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
|
||||||
create_common_attn_metadata,
|
|
||||||
create_standard_kv_cache_spec,
|
create_standard_kv_cache_spec,
|
||||||
get_attention_backend)
|
get_attention_backend)
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
||||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||||
VllmConfig)
|
VllmConfig)
|
||||||
|
|||||||
@ -6,10 +6,10 @@ from unittest import mock
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.v1.attention.utils import (BatchSpec, _Backend,
|
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
|
||||||
create_common_attn_metadata,
|
|
||||||
create_standard_kv_cache_spec,
|
create_standard_kv_cache_spec,
|
||||||
get_attention_backend)
|
get_attention_backend)
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
||||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||||
VllmConfig)
|
VllmConfig)
|
||||||
|
|||||||
@ -6,9 +6,10 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec,
|
from tests.v1.attention.utils import (create_standard_kv_cache_spec,
|
||||||
create_vllm_config,
|
create_vllm_config,
|
||||||
get_attention_backend)
|
get_attention_backend)
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.config import ParallelConfig, SpeculativeConfig
|
from vllm.config import ParallelConfig, SpeculativeConfig
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
|
|
||||||
|
|||||||
27
vllm/attention/backends/registry.py
Normal file
27
vllm/attention/backends/registry.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Attention backend registry"""
|
||||||
|
|
||||||
|
import enum
|
||||||
|
|
||||||
|
|
||||||
|
class _Backend(enum.Enum):
|
||||||
|
FLASH_ATTN = enum.auto()
|
||||||
|
TRITON_ATTN = enum.auto()
|
||||||
|
XFORMERS = enum.auto()
|
||||||
|
ROCM_FLASH = enum.auto()
|
||||||
|
ROCM_AITER_MLA = enum.auto()
|
||||||
|
ROCM_AITER_FA = enum.auto() # used for ViT attn backend
|
||||||
|
TORCH_SDPA = enum.auto()
|
||||||
|
FLASHINFER = enum.auto()
|
||||||
|
FLASHINFER_MLA = enum.auto()
|
||||||
|
TRITON_MLA = enum.auto()
|
||||||
|
CUTLASS_MLA = enum.auto()
|
||||||
|
FLASHMLA = enum.auto()
|
||||||
|
FLASH_ATTN_MLA = enum.auto()
|
||||||
|
PALLAS = enum.auto()
|
||||||
|
IPEX = enum.auto()
|
||||||
|
NO_ATTENTION = enum.auto()
|
||||||
|
FLEX_ATTENTION = enum.auto()
|
||||||
|
TREE_ATTN = enum.auto()
|
||||||
|
ROCM_ATTN = enum.auto()
|
||||||
@ -10,6 +10,7 @@ import torch.nn.functional as F
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.attention import AttentionType
|
from vllm.attention import AttentionType
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
|
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
|
||||||
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
|
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
|
||||||
from vllm.config import CacheConfig, get_current_vllm_config
|
from vllm.config import CacheConfig, get_current_vllm_config
|
||||||
@ -26,7 +27,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
|||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape)
|
GroupShape)
|
||||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||||
from vllm.platforms import _Backend, current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import GiB_bytes, direct_register_custom_op
|
from vllm.utils import GiB_bytes, direct_register_custom_op
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|||||||
@ -11,8 +11,9 @@ import torch
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import _Backend, current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
|
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import torch
|
|||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
|
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
@ -32,7 +33,7 @@ from vllm.distributed.parallel_state import (
|
|||||||
from vllm.distributed.utils import divide
|
from vllm.distributed.utils import divide
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import _Backend, current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import make_zmq_path, make_zmq_socket
|
from vllm.utils import make_zmq_path, make_zmq_socket
|
||||||
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|||||||
@ -619,8 +619,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# All possible options loaded dynamically from _Backend enum
|
# All possible options loaded dynamically from _Backend enum
|
||||||
"VLLM_ATTENTION_BACKEND":
|
"VLLM_ATTENTION_BACKEND":
|
||||||
env_with_choices("VLLM_ATTENTION_BACKEND", None,
|
env_with_choices("VLLM_ATTENTION_BACKEND", None,
|
||||||
lambda: list(__import__('vllm.platforms.interface', \
|
lambda: list(__import__(
|
||||||
fromlist=['_Backend'])._Backend.__members__.keys())),
|
'vllm.attention.backends.registry',
|
||||||
|
fromlist=['_Backend'])._Backend.__members__.keys())),
|
||||||
|
|
||||||
# If set, vllm will use flashinfer sampler
|
# If set, vllm will use flashinfer sampler
|
||||||
"VLLM_USE_FLASHINFER_SAMPLER":
|
"VLLM_USE_FLASHINFER_SAMPLER":
|
||||||
|
|||||||
@ -9,6 +9,7 @@ import torch.nn.functional as F
|
|||||||
from torch.nn import LayerNorm
|
from torch.nn import LayerNorm
|
||||||
from transformers.models.qwen2_vl import Qwen2VLProcessor
|
from transformers.models.qwen2_vl import Qwen2VLProcessor
|
||||||
|
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import check_upstream_fa_availability
|
from vllm.attention.layer import check_upstream_fa_availability
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import utils as dist_utils
|
from vllm.distributed import utils as dist_utils
|
||||||
@ -38,7 +39,6 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
|
|||||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import MultiModalDataDict
|
from vllm.multimodal.inputs import MultiModalDataDict
|
||||||
from vllm.platforms import _Backend
|
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
|
from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
|
||||||
DotsVisionConfig)
|
DotsVisionConfig)
|
||||||
|
|||||||
@ -34,6 +34,7 @@ import torch.nn.functional as F
|
|||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from transformers import BatchFeature
|
from transformers import BatchFeature
|
||||||
|
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import check_upstream_fa_availability
|
from vllm.attention.layer import check_upstream_fa_availability
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import parallel_state
|
from vllm.distributed import parallel_state
|
||||||
@ -54,7 +55,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
BaseProcessingInfo, PromptReplacement,
|
BaseProcessingInfo, PromptReplacement,
|
||||||
PromptUpdate)
|
PromptUpdate)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.platforms import _Backend, current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
|
|||||||
@ -46,6 +46,7 @@ from transformers.models.glm4v.video_processing_glm4v import (
|
|||||||
Glm4vVideoProcessor)
|
Glm4vVideoProcessor)
|
||||||
from transformers.video_utils import VideoMetadata
|
from transformers.video_utils import VideoMetadata
|
||||||
|
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import check_upstream_fa_availability
|
from vllm.attention.layer import check_upstream_fa_availability
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||||
@ -69,7 +70,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
BaseProcessingInfo, PromptReplacement,
|
BaseProcessingInfo, PromptReplacement,
|
||||||
PromptUpdate, PromptUpdateDetails)
|
PromptUpdate, PromptUpdateDetails)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.platforms import _Backend
|
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from transformers.modeling_outputs import (BaseModelOutput,
|
|||||||
BaseModelOutputWithPooling)
|
BaseModelOutputWithPooling)
|
||||||
from transformers.utils import torch_int
|
from transformers.utils import torch_int
|
||||||
|
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import check_upstream_fa_availability
|
from vllm.attention.layer import check_upstream_fa_availability
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
@ -39,7 +40,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
BaseProcessingInfo, PromptReplacement,
|
BaseProcessingInfo, PromptReplacement,
|
||||||
PromptUpdate)
|
PromptUpdate)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.platforms import _Backend
|
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils import is_list_of
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|||||||
@ -38,6 +38,7 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
|||||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||||
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
|
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
|
||||||
|
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import check_upstream_fa_availability
|
from vllm.attention.layer import check_upstream_fa_availability
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import parallel_state
|
from vllm.distributed import parallel_state
|
||||||
@ -62,7 +63,6 @@ from vllm.multimodal.evs import (compute_mrope_for_media,
|
|||||||
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
|
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
|
||||||
from vllm.multimodal.parse import MultiModalDataItems
|
from vllm.multimodal.parse import MultiModalDataItems
|
||||||
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
|
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
|
||||||
from vllm.platforms import _Backend
|
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import is_pin_memory_available
|
from vllm.utils import is_pin_memory_available
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|||||||
@ -41,6 +41,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
|
|||||||
from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
|
from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
|
||||||
Qwen2VLVideoProcessor)
|
Qwen2VLVideoProcessor)
|
||||||
|
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import check_upstream_fa_availability
|
from vllm.attention.layer import check_upstream_fa_availability
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
|
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
|
||||||
@ -65,7 +66,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
BaseProcessingInfo, PromptReplacement,
|
BaseProcessingInfo, PromptReplacement,
|
||||||
PromptUpdate)
|
PromptUpdate)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.platforms import _Backend
|
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|||||||
@ -43,6 +43,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
|
|||||||
smart_resize as video_smart_resize)
|
smart_resize as video_smart_resize)
|
||||||
from transformers.video_utils import VideoMetadata
|
from transformers.video_utils import VideoMetadata
|
||||||
|
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import check_upstream_fa_availability
|
from vllm.attention.layer import check_upstream_fa_availability
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
@ -66,7 +67,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
PromptReplacement, PromptUpdate,
|
PromptReplacement, PromptUpdate,
|
||||||
PromptUpdateDetails)
|
PromptUpdateDetails)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.platforms import _Backend
|
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils import is_list_of
|
||||||
|
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from torch.nn import functional as F
|
|||||||
from transformers import Siglip2VisionConfig
|
from transformers import Siglip2VisionConfig
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import check_upstream_fa_availability
|
from vllm.attention.layer import check_upstream_fa_availability
|
||||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
@ -22,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.platforms import _Backend
|
|
||||||
|
|
||||||
from .vision import get_vit_attn_backend
|
from .vision import get_vit_attn_backend
|
||||||
|
|
||||||
|
|||||||
@ -10,11 +10,12 @@ from typing import (Callable, Final, Generic, Literal, Optional, Protocol,
|
|||||||
import torch
|
import torch
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_gather)
|
tensor_model_parallel_all_gather)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import _Backend, current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -9,7 +9,6 @@ from vllm import envs
|
|||||||
from vllm.plugins import load_plugins_by_group
|
from vllm.plugins import load_plugins_by_group
|
||||||
from vllm.utils import resolve_obj_by_qualname, supports_xccl
|
from vllm.utils import resolve_obj_by_qualname, supports_xccl
|
||||||
|
|
||||||
from .interface import _Backend # noqa: F401
|
|
||||||
from .interface import CpuArchEnum, Platform, PlatformEnum
|
from .interface import CpuArchEnum, Platform, PlatformEnum
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@ -15,13 +15,15 @@ import torch
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||||
|
|
||||||
from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend
|
from .interface import CpuArchEnum, Platform, PlatformEnum
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
else:
|
else:
|
||||||
|
_Backend = None
|
||||||
VllmConfig = None
|
VllmConfig = None
|
||||||
|
|
||||||
|
|
||||||
@ -90,10 +92,11 @@ class CpuPlatform(Platform):
|
|||||||
return "cpu"
|
return "cpu"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
|
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
|
||||||
dtype: torch.dtype, kv_cache_dtype: Optional[str],
|
dtype: torch.dtype, kv_cache_dtype: Optional[str],
|
||||||
block_size: int, use_v1: bool, use_mla: bool,
|
block_size: int, use_v1: bool, use_mla: bool,
|
||||||
has_sink: bool, use_sparse: bool) -> str:
|
has_sink: bool, use_sparse: bool) -> str:
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
|
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
|
||||||
logger.info("Cannot use %s backend on CPU.", selected_backend)
|
logger.info("Cannot use %s backend on CPU.", selected_backend)
|
||||||
if use_mla:
|
if use_mla:
|
||||||
|
|||||||
@ -20,10 +20,13 @@ import vllm.envs as envs
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import cuda_device_count_stateless, import_pynvml
|
from vllm.utils import cuda_device_count_stateless, import_pynvml
|
||||||
|
|
||||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.config import ModelConfig, VllmConfig
|
from vllm.config import ModelConfig, VllmConfig
|
||||||
|
else:
|
||||||
|
_Backend = None
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -202,7 +205,8 @@ class CudaPlatformBase(Platform):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_vit_attn_backend(cls, head_size: int,
|
def get_vit_attn_backend(cls, head_size: int,
|
||||||
dtype: torch.dtype) -> _Backend:
|
dtype: torch.dtype) -> "_Backend":
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
|
|
||||||
# For Blackwell GPUs, force TORCH_SDPA for now.
|
# For Blackwell GPUs, force TORCH_SDPA for now.
|
||||||
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
|
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
|
||||||
@ -230,6 +234,7 @@ class CudaPlatformBase(Platform):
|
|||||||
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
|
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
|
||||||
kv_cache_dtype, block_size, use_v1, use_mla,
|
kv_cache_dtype, block_size, use_v1, use_mla,
|
||||||
has_sink, use_sparse) -> str:
|
has_sink, use_sparse) -> str:
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
if use_mla:
|
if use_mla:
|
||||||
if not use_v1:
|
if not use_v1:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|||||||
@ -17,12 +17,14 @@ from vllm.inputs import ProcessorInputs, PromptType
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.config import ModelConfig, VllmConfig
|
from vllm.config import ModelConfig, VllmConfig
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
else:
|
else:
|
||||||
|
_Backend = None
|
||||||
ModelConfig = None
|
ModelConfig = None
|
||||||
VllmConfig = None
|
VllmConfig = None
|
||||||
LoRARequest = None
|
LoRARequest = None
|
||||||
@ -38,30 +40,6 @@ def in_wsl() -> bool:
|
|||||||
return "microsoft" in " ".join(uname()).lower()
|
return "microsoft" in " ".join(uname()).lower()
|
||||||
|
|
||||||
|
|
||||||
class _Backend(enum.Enum):
|
|
||||||
FLASH_ATTN = enum.auto()
|
|
||||||
TRITON_ATTN = enum.auto()
|
|
||||||
XFORMERS = enum.auto()
|
|
||||||
ROCM_FLASH = enum.auto()
|
|
||||||
ROCM_AITER_MLA = enum.auto() # Supported by V1
|
|
||||||
ROCM_AITER_FA = enum.auto() # used for ViT attn backend
|
|
||||||
TORCH_SDPA = enum.auto()
|
|
||||||
FLASHINFER = enum.auto()
|
|
||||||
FLASHINFER_MLA = enum.auto()
|
|
||||||
TRITON_MLA = enum.auto() # Supported by V1
|
|
||||||
CUTLASS_MLA = enum.auto()
|
|
||||||
FLASHMLA = enum.auto() # Supported by V1
|
|
||||||
FLASH_ATTN_MLA = enum.auto() # Supported by V1
|
|
||||||
PALLAS = enum.auto()
|
|
||||||
IPEX = enum.auto()
|
|
||||||
DUAL_CHUNK_FLASH_ATTN = enum.auto()
|
|
||||||
DIFFERENTIAL_FLASH_ATTN = enum.auto()
|
|
||||||
NO_ATTENTION = enum.auto()
|
|
||||||
FLEX_ATTENTION = enum.auto()
|
|
||||||
TREE_ATTN = enum.auto()
|
|
||||||
ROCM_ATTN = enum.auto()
|
|
||||||
|
|
||||||
|
|
||||||
class PlatformEnum(enum.Enum):
|
class PlatformEnum(enum.Enum):
|
||||||
CUDA = enum.auto()
|
CUDA = enum.auto()
|
||||||
ROCM = enum.auto()
|
ROCM = enum.auto()
|
||||||
@ -187,11 +165,12 @@ class Platform:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_vit_attn_backend(cls, head_size: int,
|
def get_vit_attn_backend(cls, head_size: int,
|
||||||
dtype: torch.dtype) -> _Backend:
|
dtype: torch.dtype) -> "_Backend":
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
return _Backend.TORCH_SDPA
|
return _Backend.TORCH_SDPA
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
|
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
|
||||||
dtype: torch.dtype, kv_cache_dtype: Optional[str],
|
dtype: torch.dtype, kv_cache_dtype: Optional[str],
|
||||||
block_size: int, use_v1: bool, use_mla: bool,
|
block_size: int, use_v1: bool, use_mla: bool,
|
||||||
has_sink: bool, use_sparse: bool) -> str:
|
has_sink: bool, use_sparse: bool) -> str:
|
||||||
|
|||||||
@ -14,10 +14,13 @@ import vllm.envs as envs
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import cuda_device_count_stateless
|
from vllm.utils import cuda_device_count_stateless
|
||||||
|
|
||||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.config import ModelConfig, VllmConfig
|
from vllm.config import ModelConfig, VllmConfig
|
||||||
|
else:
|
||||||
|
_Backend = None
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -182,7 +185,8 @@ class RocmPlatform(Platform):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_vit_attn_backend(cls, head_size: int,
|
def get_vit_attn_backend(cls, head_size: int,
|
||||||
dtype: torch.dtype) -> _Backend:
|
dtype: torch.dtype) -> "_Backend":
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
|
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
|
||||||
and on_gfx9()):
|
and on_gfx9()):
|
||||||
# Note: AITER FA is only supported for Qwen-VL models.
|
# Note: AITER FA is only supported for Qwen-VL models.
|
||||||
@ -196,6 +200,7 @@ class RocmPlatform(Platform):
|
|||||||
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
|
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
|
||||||
kv_cache_dtype, block_size, use_v1, use_mla,
|
kv_cache_dtype, block_size, use_v1, use_mla,
|
||||||
has_sink, use_sparse) -> str:
|
has_sink, use_sparse) -> str:
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
if use_sparse:
|
if use_sparse:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Sparse Attention is not supported on ROCm.")
|
"Sparse Attention is not supported on ROCm.")
|
||||||
|
|||||||
@ -11,9 +11,10 @@ from vllm.logger import init_logger
|
|||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||||
|
|
||||||
from .interface import Platform, PlatformEnum, _Backend
|
from .interface import Platform, PlatformEnum
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.config import BlockSize, ModelConfig, VllmConfig
|
from vllm.config import BlockSize, ModelConfig, VllmConfig
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
else:
|
else:
|
||||||
@ -21,6 +22,7 @@ else:
|
|||||||
ModelConfig = None
|
ModelConfig = None
|
||||||
VllmConfig = None
|
VllmConfig = None
|
||||||
PoolingParams = None
|
PoolingParams = None
|
||||||
|
_Backend = None
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -46,10 +48,11 @@ class TpuPlatform(Platform):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
|
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
|
||||||
dtype: torch.dtype, kv_cache_dtype: Optional[str],
|
dtype: torch.dtype, kv_cache_dtype: Optional[str],
|
||||||
block_size: int, use_v1: bool, use_mla: bool,
|
block_size: int, use_v1: bool, use_mla: bool,
|
||||||
has_sink, use_sparse) -> str:
|
has_sink, use_sparse) -> str:
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
if use_sparse:
|
if use_sparse:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Sparse Attention is not supported on TPU.")
|
"Sparse Attention is not supported on TPU.")
|
||||||
|
|||||||
@ -10,13 +10,15 @@ import vllm.envs as envs
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||||
|
|
||||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.config import ModelConfig, VllmConfig
|
from vllm.config import ModelConfig, VllmConfig
|
||||||
else:
|
else:
|
||||||
ModelConfig = None
|
ModelConfig = None
|
||||||
VllmConfig = None
|
VllmConfig = None
|
||||||
|
_Backend = None
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -33,10 +35,11 @@ class XPUPlatform(Platform):
|
|||||||
device_control_env_var: str = "ZE_AFFINITY_MASK"
|
device_control_env_var: str = "ZE_AFFINITY_MASK"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
|
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
|
||||||
dtype: torch.dtype, kv_cache_dtype: Optional[str],
|
dtype: torch.dtype, kv_cache_dtype: Optional[str],
|
||||||
block_size: int, use_v1: bool, use_mla: bool,
|
block_size: int, use_v1: bool, use_mla: bool,
|
||||||
has_sink: bool, use_sparse) -> str:
|
has_sink: bool, use_sparse) -> str:
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
if use_sparse:
|
if use_sparse:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Sparse Attention is not supported on XPU.")
|
"Sparse Attention is not supported on XPU.")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user