[Attention] Update attention imports (#29540)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Author: Matthew Bonanni, 2025-11-27 11:19:09 -05:00 (committed via GitHub)
commit fc1d8be3dc (parent cd007a53b4)
38 changed files with 63 additions and 126 deletions
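Most of the hunks below apply one pattern: deferred, function-level, or TYPE_CHECKING-only imports of the attention types (AttentionMetadata, AttentionBackend, AttentionBackendEnum, formerly _Backend) become ordinary module-level imports, and the matching string annotations lose their quotes. A minimal before/after sketch of the pattern (the connector class and method names here are illustrative, not taken from the diff):

# Before: import deferred to type-checking time, annotation quoted.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

class ExampleConnector:
    def save_kv_layer(self, attn_metadata: "AttentionMetadata") -> None:
        ...

# After: module-level import, unquoted annotation. This is safe once
# vllm.attention.backends.abstract itself no longer imports heavy modules.
from vllm.attention.backends.abstract import AttentionMetadata

class ExampleConnector:
    def save_kv_layer(self, attn_metadata: AttentionMetadata) -> None:
        ...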

View File

@ -139,14 +139,13 @@ def test_standard_attention_backend_selection(
import importlib
import vllm.envs as envs
- from vllm.attention.backends.registry import _Backend
importlib.reload(envs)
# Convert string backend to enum if provided
backend_enum = None
if selected_backend:
- backend_enum = getattr(_Backend, selected_backend)
+ backend_enum = getattr(AttentionBackendEnum, selected_backend)
# Get the backend class path
from vllm.platforms.rocm import RocmPlatform
@ -253,7 +252,6 @@ def test_mla_backend_selection(
import importlib
import vllm.envs as envs
- from vllm.attention.backends.registry import _Backend
importlib.reload(envs)
@ -269,7 +267,7 @@ def test_mla_backend_selection(
# Convert string backend to enum if provided
backend_enum = None
if selected_backend:
- backend_enum = getattr(_Backend, selected_backend)
+ backend_enum = getattr(AttentionBackendEnum, selected_backend)
from vllm.platforms.rocm import RocmPlatform
@ -301,7 +299,6 @@ def test_mla_backend_selection(
def test_aiter_fa_requires_gfx9(mock_vllm_config):
"""Test that ROCM_AITER_FA requires gfx9 architecture."""
- from vllm.attention.backends.registry import _Backend
from vllm.platforms.rocm import RocmPlatform
# Mock on_gfx9 to return False
@ -313,7 +310,7 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config):
),
):
RocmPlatform.get_attn_backend_cls(
- selected_backend=_Backend.ROCM_AITER_FA,
+ selected_backend=AttentionBackendEnum.ROCM_AITER_FA,
head_size=128,
dtype=torch.float16,
kv_cache_dtype="auto",

View File

@ -14,6 +14,7 @@ from unittest.mock import patch
import pytest
+ from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
@ -24,7 +25,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from .utils import create_scheduler, create_vllm_config
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@ -68,7 +68,7 @@ class OldStyleTestConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
pass
@ -119,7 +119,7 @@ class NewStyleTestConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
pass

View File

@ -6,11 +6,10 @@ from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
import torch
- from vllm.model_executor.layers.linear import ColumnParallelLinear
- from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
if TYPE_CHECKING:
from vllm.config.cache import CacheDType
+ from vllm.model_executor.layers.linear import ColumnParallelLinear
+ from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import KVCacheLayoutType
@ -178,8 +177,6 @@ class AttentionBackend(ABC):
By default, only supports decoder attention.
Backends should override this to support other attention types.
"""
- from vllm.attention.backends.abstract import AttentionType
return attn_type == AttentionType.DECODER
@classmethod
@ -360,7 +357,7 @@ class AttentionImpl(ABC, Generic[T]):
) -> torch.Tensor:
raise NotImplementedError
- def fused_output_quant_supported(self, quant_key: QuantKey):
+ def fused_output_quant_supported(self, quant_key: "QuantKey"):
"""
Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization
@ -412,7 +409,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
- kv_b_proj: ColumnParallelLinear,
+ kv_b_proj: "ColumnParallelLinear",
indexer: object | None = None,
) -> None:
raise NotImplementedError
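Note that abstract.py itself moves the other way: ColumnParallelLinear and QuantKey are now imported only under TYPE_CHECKING and the corresponding annotations become strings, which is what keeps the module light enough for everyone else to import eagerly. A rough sketch of that guard, using a hypothetical stand-in class:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers; at runtime this module never pulls
    # in the linear layers, avoiding the circular import.
    from vllm.model_executor.layers.linear import ColumnParallelLinear

class MLAImplSketch:  # hypothetical stand-in for MLAAttentionImpl
    def __init__(self, kv_b_proj: "ColumnParallelLinear") -> None:
        self.kv_b_proj = kv_b_proj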

View File

@ -5,6 +5,7 @@ import functools
import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
+ from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig
@ -22,8 +23,6 @@ from vllm.v1.kv_cache_interface import (
KVCacheSpec,
)
- from ..layer import Attention
@functools.lru_cache
def create_chunked_local_attention_backend(

View File

@ -14,6 +14,7 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
import vllm.envs as envs
+ from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType
@ -53,7 +54,6 @@ if TYPE_CHECKING:
import vllm.model_executor.layers.quantization as me_quant
import vllm.model_executor.models as me_models
- from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.load import LoadConfig
from vllm.config.parallel import ParallelConfig
from vllm.model_executor.layers.quantization import QuantizationMethods
@ -61,7 +61,6 @@ if TYPE_CHECKING:
else:
PretrainedConfig = Any
- AttentionBackendEnum = Any
me_quant = LazyLoader(
"model_executor", globals(), "vllm.model_executor.layers.quantization"
)

View File

@ -2,19 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
- from typing import TYPE_CHECKING, Any, Literal, TypeAlias
+ from typing import Any, Literal, TypeAlias
from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
+ from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
- if TYPE_CHECKING:
- from vllm.attention.backends.registry import AttentionBackendEnum
- else:
- AttentionBackendEnum = Any
@dataclass
class BaseDummyOptions:
@ -170,9 +166,6 @@ class MultiModalConfig:
def _validate_mm_encoder_attn_backend(
cls, value: str | AttentionBackendEnum | None
) -> AttentionBackendEnum | None:
- # We need to import the real type here (deferred to avoid circular import).
- from vllm.attention.backends.registry import AttentionBackendEnum
if isinstance(value, str) and value.upper() == "XFORMERS":
raise ValueError(
"Attention backend 'XFORMERS' has been removed (See PR #29262 for "

View File

@ -42,12 +42,12 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional
import torch
+ from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
@ -239,7 +239,7 @@ class KVConnectorBase_V1(ABC):
return
def register_cross_layers_kv_cache(
- self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
+ self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
"""
Initialize with a single KV cache tensor used by all layers.

View File

@ -36,6 +36,7 @@ from typing import TYPE_CHECKING, Any, Optional
import torch
+ from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
@ -45,7 +46,6 @@ from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@ -117,7 +117,7 @@ class DecodeBenchConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
# This connector doesn't save KV cache (benchmarking only)

View File

@ -7,6 +7,7 @@ from lmcache.integration.vllm.vllm_v1_adapter import (
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
)
+ from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
@ -17,7 +18,6 @@ from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
@ -91,7 +91,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
"""

View File

@ -29,6 +29,7 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
from lmcache.v1.plugin.plugin_launcher import PluginLauncher
+ from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
@ -50,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.version import __version__ as VLLM_VERSION
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.multimodal.inputs import PlaceholderRange
from vllm.v1.core.kv_cache_manager import KVCacheManager
@ -915,7 +915,7 @@ class LMCacheConnectorV1Impl:
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
"""Start saving the a layer of KV cache from vLLM's paged buffer

View File

@ -10,6 +10,7 @@ import zmq
from lmcache.integration.vllm.utils import mla_enabled
from lmcache.utils import init_logger as lmcache_init_logger
+ from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
@ -26,7 +27,6 @@ from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
@ -490,7 +490,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
"""

View File

@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any
import torch
+ from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
@ -27,7 +28,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@ -216,7 +216,7 @@ class MultiConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
for c in self._connectors:

View File

@ -20,7 +20,7 @@ import torch
import zmq
from vllm import envs
- from vllm.attention.backends.abstract import AttentionBackend
+ from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
@ -51,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
@ -308,7 +307,7 @@ class NixlConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
"""NixlConnector does not save explicitly."""

View File

@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
import regex as re
import torch
+ from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
@ -22,7 +23,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
@ -243,7 +243,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer

View File

@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
import safetensors
import torch
+ from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
@ -19,7 +20,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
@ -211,7 +211,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer

View File

@ -5,19 +5,17 @@ import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
- from typing import TYPE_CHECKING, Any, NamedTuple
+ from typing import Any, NamedTuple
import torch
import vllm.envs as envs
+ from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ubatch_utils import UBatchSlices
- if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
logger = init_logger(__name__)
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
@ -195,7 +193,7 @@ class ForwardContext:
for each microbatch.
Set dynamically for each forward pass
"""
- attn_metadata: dict[str, "AttentionMetadata"] | list[dict[str, "AttentionMetadata"]]
+ attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass

View File

@ -3,14 +3,11 @@
"""Base class for attention-like layers."""
from abc import ABC, abstractmethod
- from typing import TYPE_CHECKING
+ from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheSpec
- if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionBackend
class AttentionLayerBase(ABC):
"""
@ -22,7 +19,7 @@ class AttentionLayerBase(ABC):
"""
@abstractmethod
- def get_attn_backend(self) -> type["AttentionBackend"]:
+ def get_attn_backend(self) -> type[AttentionBackend]:
"""Get the attention backend class for this layer."""
pass

View File

@ -2,18 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Iterable
- from typing import TYPE_CHECKING
import torch
+ from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import get_mamba_attn_backend
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
- if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionBackend
class MambaBase(AttentionLayerBase):
"""
@ -66,6 +63,6 @@ class MambaBase(AttentionLayerBase):
),
)
- def get_attn_backend(self) -> type["AttentionBackend"]:
+ def get_attn_backend(self) -> type[AttentionBackend]:
"""Get the attention backend class for this Mamba layer."""
return get_mamba_attn_backend(self.mamba_type)

View File

@ -18,6 +18,7 @@ from compressed_tensors.quantization import (
from compressed_tensors.transform import TransformConfig
import vllm.envs as envs
+ from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
@ -131,8 +132,6 @@ class CompressedTensorsConfig(QuantizationConfig):
layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
# collect schemes
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)

View File

@ -14,6 +14,7 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
+ from vllm.attention.layer import Attention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
@ -277,7 +278,6 @@ class Fp8Config(QuantizationConfig):
def get_xpu_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import Attention
from vllm.model_executor.layers.quantization.ipex_quant import (
XPUFp8LinearMethod,
XPUFp8MoEMethod,
@ -307,8 +307,6 @@ class Fp8Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import Attention # Avoid circular import
if current_platform.is_xpu():
return self.get_xpu_quant_method(layer, prefix)
if isinstance(layer, LinearBase):

View File

@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
+ from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
@ -149,8 +150,6 @@ class ModelOptQuantConfigBase(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import Attention # Avoid circular import
# handle kv-cache first so we can focus only on weight quantization thereafter
if isinstance(layer, Attention):
return self.KVCacheMethodCls(self)

View File

@ -8,6 +8,7 @@ import torch
from torch.nn.parameter import Parameter
from vllm import envs
+ from vllm.attention.layer import Attention
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
@ -184,8 +185,6 @@ class Mxfp4Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if self.ignored_layers and is_layer_skipped(
prefix=prefix,

View File

@ -8,6 +8,7 @@ import regex as re
import torch
from torch.nn.parameter import Parameter
+ from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (
LinearBase,
@ -159,8 +160,6 @@ class PetitNvFp4Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import Attention # Avoid circular import
exclude = self.require_exclude_modules()
if isinstance(layer, LinearBase):

View File

@ -7,6 +7,7 @@ import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
+ from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
@ -65,8 +66,6 @@ class PTPCFp8Config(Fp8Config):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
return UnquantizedLinearMethod()

View File

@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
import torch
+ from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
@ -102,8 +103,6 @@ class QuarkConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import Attention # Avoid circular import
# Check if the layer is skipped for quantization.
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
if should_ignore_layer(

View File

@ -14,6 +14,7 @@ import regex as re
import torch
from vllm import envs
+ from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from .interface import CpuArchEnum, Platform, PlatformEnum
@ -21,10 +22,8 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
logger = init_logger(__name__)
if TYPE_CHECKING:
- from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
else:
- AttentionBackendEnum = None
VllmConfig = None
@ -135,8 +134,6 @@ class CpuPlatform(Platform):
use_sparse: bool,
attn_type: str | None = None,
) -> str:
- from vllm.attention.backends.registry import AttentionBackendEnum
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:

View File

@ -15,6 +15,8 @@ from typing_extensions import ParamSpec
# import custom ops, trigger op registration
import vllm._C # noqa
import vllm.envs as envs
+ from vllm.attention.backends.abstract import AttentionType
+ from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml
from vllm.utils.torch_utils import cuda_device_count_stateless
@ -22,11 +24,9 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
- from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
else:
- AttentionBackendEnum = None
VllmConfig = None
CacheDType = None
@ -48,8 +48,6 @@ def _get_backend_priorities(
device_capability: DeviceCapability,
) -> list[AttentionBackendEnum]:
"""Get backend priorities with lazy import to avoid circular dependency."""
- from vllm.attention.backends.registry import AttentionBackendEnum
if use_mla:
if device_capability.major == 10:
return [
@ -265,8 +263,6 @@ class CudaPlatformBase(Platform):
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
- from vllm.attention.backends.registry import AttentionBackendEnum
# Try FlashAttention first
try:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
@ -335,8 +331,6 @@ class CudaPlatformBase(Platform):
use_sparse: bool,
attn_type: str | None = None,
) -> str:
- from vllm.attention.backends.abstract import AttentionType
if attn_type is None:
attn_type = AttentionType.DECODER

View File

@ -12,12 +12,12 @@ from typing import TYPE_CHECKING, Any, NamedTuple
import numpy as np
import torch
+ from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup
- from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.inputs import ProcessorInputs, PromptType
@ -226,9 +226,6 @@ class Platform:
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
- # Import AttentionBackendEnum here to avoid circular import.
- from vllm.attention.backends.registry import AttentionBackendEnum
return AttentionBackendEnum.TORCH_SDPA
@classmethod

View File

@ -8,16 +8,14 @@ from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
+ from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
- from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
else:
- AttentionBackendEnum = None
logger = init_logger(__name__)
@ -196,7 +194,6 @@ class RocmPlatform(Platform):
from importlib.util import find_spec
from vllm._aiter_ops import rocm_aiter_ops
- from vllm.attention.backends.registry import AttentionBackendEnum
if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models.
@ -222,7 +219,6 @@ class RocmPlatform(Platform):
attn_type: str | None = None,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
- from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse:
if kv_cache_dtype.startswith("fp8"):

View File

@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, cast
import torch
from tpu_info import device
+ from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
@ -15,7 +16,6 @@ from .interface import Platform, PlatformEnum
if TYPE_CHECKING:
from typing import TypeAlias
- from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams
@ -26,7 +26,6 @@ else:
BlockSize = None
VllmConfig = None
PoolingParams = None
- AttentionBackendEnum = None
ParamsType = None
logger = init_logger(__name__)
@ -67,8 +66,6 @@ class TpuPlatform(Platform):
use_sparse,
attn_type: str | None = None,
) -> str:
- from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on TPU.")
if selected_backend != AttentionBackendEnum.PALLAS:

View File

@ -8,16 +8,15 @@ from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
+ from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
- from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
else:
VllmConfig = None
- AttentionBackendEnum = None
logger = init_logger(__name__)
@ -60,8 +59,6 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels."
)
- from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.")
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
@ -116,8 +113,6 @@ class XPUPlatform(Platform):
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
- from vllm.attention.backends.registry import AttentionBackendEnum
return AttentionBackendEnum.FLASH_ATTN
@classmethod

View File

@ -51,8 +51,6 @@ class CPUAttentionBackend(AttentionBackend):
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""CPU attention supports decoder and encoder-only attention."""
- from vllm.attention.backends.abstract import AttentionType
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,

View File

@ -84,8 +84,6 @@ class FlashAttentionBackend(AttentionBackend):
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlashAttention supports all attention types."""
- from vllm.attention.backends.abstract import AttentionType
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,

View File

@ -87,8 +87,6 @@ class FlexAttentionBackend(AttentionBackend):
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlexAttention supports both decoder and encoder-only attention."""
- from vllm.attention.backends.abstract import AttentionType
return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
@staticmethod

View File

@ -24,12 +24,15 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils.math_utils import cdiv
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionImpl
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
import vllm.envs as envs
- from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
+ from vllm.attention.backends.abstract import (
+ AttentionBackend,
+ AttentionImpl,
+ AttentionMetadata,
+ )
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout,
)

View File

@ -6,12 +6,12 @@ from typing import TYPE_CHECKING
import torch
+ from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
logger = init_logger(__name__)
@ -51,7 +51,7 @@ class OffloadingSpec(ABC):
def get_handlers(
self,
kv_caches: dict[str, torch.Tensor],
- attn_backends: dict[str, type["AttentionBackend"]],
+ attn_backends: dict[str, type[AttentionBackend]],
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
"""
Get offloading handlers along with their respective src and dst types.

View File

@ -8,6 +8,7 @@ import numpy as np
import torch
import torch.nn as nn
+ from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
CompilationMode,
CUDAGraphMode,
@ -157,8 +158,6 @@ class EagleProposer:
)
# Determine allowed attention backends once during initialization.
- from vllm.attention.backends.registry import AttentionBackendEnum
self.allowed_attn_types: tuple | None = None
if current_platform.is_rocm():
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]

View File

@ -2,11 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass, field
- from typing import TYPE_CHECKING
import torch
from vllm.attention.backends.abstract import AttentionBackend
+ from vllm.attention.layer import Attention
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
@ -17,9 +17,6 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
- if TYPE_CHECKING:
- from vllm.attention.layer import Attention
class MultiModalBudget:
"""Helper class to calculate budget information for multi-modal models."""
@ -278,7 +275,7 @@ def add_kv_sharing_layers_to_kv_cache_groups(
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
- forward_context: dict[str, "Attention"],
+ forward_context: dict[str, Attention],
runner_kv_caches: list[torch.Tensor],
num_attn_module: int = 1,
) -> None: