[Attention] Update attention imports (#29540)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Matthew Bonanni, 2025-11-27 11:19:09 -05:00 (committed by GitHub)
parent cd007a53b4
commit fc1d8be3dc
38 changed files with 63 additions and 126 deletions
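
Most of the diffs below apply the same mechanical change: attention types that were previously imported lazily (under `if TYPE_CHECKING:` or inside function bodies, with quoted forward-reference annotations) are now imported at module scope, and references to the old `_Backend` name are updated to `AttentionBackendEnum`. A minimal before/after sketch of the pattern, using a hypothetical connector hook for illustration (the function names here are not from the diff):

# Before: deferred import plus string forward reference.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

def save_kv_layer_old(layer_name: str, attn_metadata: "AttentionMetadata") -> None:
    # The annotation is only resolved by type checkers.
    pass

# After: plain module-level import and a real annotation.
from vllm.attention.backends.abstract import AttentionMetadata

def save_kv_layer_new(layer_name: str, attn_metadata: AttentionMetadata) -> None:
    pass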

View File

@@ -139,14 +139,13 @@ def test_standard_attention_backend_selection(
 import importlib
 import vllm.envs as envs
-from vllm.attention.backends.registry import _Backend
 importlib.reload(envs)
 # Convert string backend to enum if provided
 backend_enum = None
 if selected_backend:
-backend_enum = getattr(_Backend, selected_backend)
+backend_enum = getattr(AttentionBackendEnum, selected_backend)
 # Get the backend class path
 from vllm.platforms.rocm import RocmPlatform
@@ -253,7 +252,6 @@ def test_mla_backend_selection(
 import importlib
 import vllm.envs as envs
-from vllm.attention.backends.registry import _Backend
 importlib.reload(envs)
@@ -269,7 +267,7 @@ def test_mla_backend_selection(
 # Convert string backend to enum if provided
 backend_enum = None
 if selected_backend:
-backend_enum = getattr(_Backend, selected_backend)
+backend_enum = getattr(AttentionBackendEnum, selected_backend)
 from vllm.platforms.rocm import RocmPlatform
@@ -301,7 +299,6 @@ def test_mla_backend_selection(
 def test_aiter_fa_requires_gfx9(mock_vllm_config):
 """Test that ROCM_AITER_FA requires gfx9 architecture."""
-from vllm.attention.backends.registry import _Backend
 from vllm.platforms.rocm import RocmPlatform
 # Mock on_gfx9 to return False
@@ -313,7 +310,7 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config):
 ),
 ):
 RocmPlatform.get_attn_backend_cls(
-selected_backend=_Backend.ROCM_AITER_FA,
+selected_backend=AttentionBackendEnum.ROCM_AITER_FA,
 head_size=128,
 dtype=torch.float16,
 kv_cache_dtype="auto",

View File

@@ -14,6 +14,7 @@ from unittest.mock import patch
 import pytest
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
 from vllm.distributed.kv_transfer.kv_connector.v1 import (
 KVConnectorBase_V1,
@@ -24,7 +25,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from .utils import create_scheduler, create_vllm_config
 if TYPE_CHECKING:
-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.forward_context import ForwardContext
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -68,7 +68,7 @@ class OldStyleTestConnector(KVConnectorBase_V1):
 self,
 layer_name: str,
 kv_layer,
-attn_metadata: "AttentionMetadata",
+attn_metadata: AttentionMetadata,
 **kwargs,
 ) -> None:
 pass
@@ -119,7 +119,7 @@ class NewStyleTestConnector(KVConnectorBase_V1):
 self,
 layer_name: str,
 kv_layer,
-attn_metadata: "AttentionMetadata",
+attn_metadata: AttentionMetadata,
 **kwargs,
 ) -> None:
 pass

View File

@@ -6,11 +6,10 @@ from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
 import torch
-from vllm.model_executor.layers.linear import ColumnParallelLinear
-from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
 if TYPE_CHECKING:
 from vllm.config.cache import CacheDType
+from vllm.model_executor.layers.linear import ColumnParallelLinear
+from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
 from vllm.platforms.interface import DeviceCapability
 from vllm.v1.attention.backends.utils import KVCacheLayoutType
@@ -178,8 +177,6 @@ class AttentionBackend(ABC):
 By default, only supports decoder attention.
 Backends should override this to support other attention types.
 """
-from vllm.attention.backends.abstract import AttentionType
 return attn_type == AttentionType.DECODER
 @classmethod
@@ -360,7 +357,7 @@ class AttentionImpl(ABC, Generic[T]):
 ) -> torch.Tensor:
 raise NotImplementedError
-def fused_output_quant_supported(self, quant_key: QuantKey):
+def fused_output_quant_supported(self, quant_key: "QuantKey"):
 """
 Does this attention implementation support fused output quantization.
 This is used by the AttnFusionPass to only fuse output quantization
@@ -412,7 +409,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
 qk_rope_head_dim: int,
 qk_head_dim: int,
 v_head_dim: int,
-kv_b_proj: ColumnParallelLinear,
+kv_b_proj: "ColumnParallelLinear",
 indexer: object | None = None,
 ) -> None:
 raise NotImplementedError
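
The file above (the abstract attention backend and impl interfaces) moves in the opposite direction from most of this commit: because other modules now import it at module scope, its own references to model_executor types are pushed under `if TYPE_CHECKING:` and the corresponding signatures switch to string forward references, which keeps the import graph acyclic at runtime. A simplified sketch of that pattern (the class and method names here are illustrative, not the actual interface):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Resolved only by type checkers; never imported at runtime.
    from vllm.model_executor.layers.linear import ColumnParallelLinear

class ExampleMLAImpl:
    def process_weights_after_loading(self, kv_b_proj: "ColumnParallelLinear") -> None:
        # The quoted annotation avoids needing the runtime import above.
        raise NotImplementedError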

View File

@@ -5,6 +5,7 @@ import functools
 import torch
 from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
+from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig
 from vllm.config.vllm import VllmConfig
@@ -22,8 +23,6 @@ from vllm.v1.kv_cache_interface import (
 KVCacheSpec,
 )
-from ..layer import Attention
 @functools.lru_cache
 def create_chunked_local_attention_backend(

View File

@@ -14,6 +14,7 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
 from transformers.configuration_utils import ALLOWED_LAYER_TYPES
 import vllm.envs as envs
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
 from vllm.config.pooler import PoolerConfig
 from vllm.config.scheduler import RunnerType
@@ -53,7 +54,6 @@ if TYPE_CHECKING:
 import vllm.model_executor.layers.quantization as me_quant
 import vllm.model_executor.models as me_models
-from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config.load import LoadConfig
 from vllm.config.parallel import ParallelConfig
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -61,7 +61,6 @@ if TYPE_CHECKING:
 else:
 PretrainedConfig = Any
-AttentionBackendEnum = Any
 me_quant = LazyLoader(
 "model_executor", globals(), "vllm.model_executor.layers.quantization"
 )

View File

@@ -2,19 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Mapping
-from typing import TYPE_CHECKING, Any, Literal, TypeAlias
+from typing import Any, Literal, TypeAlias
 from pydantic import ConfigDict, Field, field_validator, model_validator
 from pydantic.dataclasses import dataclass
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config.utils import config
 from vllm.utils.hashing import safe_hash
-if TYPE_CHECKING:
-from vllm.attention.backends.registry import AttentionBackendEnum
-else:
-AttentionBackendEnum = Any
 @dataclass
 class BaseDummyOptions:
@@ -170,9 +166,6 @@ class MultiModalConfig:
 def _validate_mm_encoder_attn_backend(
 cls, value: str | AttentionBackendEnum | None
 ) -> AttentionBackendEnum | None:
-# We need to import the real type here (deferred to avoid circular import).
-from vllm.attention.backends.registry import AttentionBackendEnum
 if isinstance(value, str) and value.upper() == "XFORMERS":
 raise ValueError(
 "Attention backend 'XFORMERS' has been removed (See PR #29262 for "

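For context on the deletions above: the previous code used a common placeholder idiom so that the name `AttentionBackendEnum` existed at runtime even though the real class was only imported for type checkers, roughly like this (a simplified sketch, not the exact original):

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from vllm.attention.backends.registry import AttentionBackendEnum
else:
    # Runtime placeholder so the annotation name still resolves.
    AttentionBackendEnum = Any

With the registry now importable at module level, the placeholder and the deferred re-import inside the validator are no longer needed.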
View File

@@ -42,12 +42,12 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional
 import torch
+from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
 from vllm.logger import init_logger
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.outputs import KVConnectorOutput
 if TYPE_CHECKING:
-from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed.kv_events import KVCacheEvent
 from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
@@ -239,7 +239,7 @@ class KVConnectorBase_V1(ABC):
 return
 def register_cross_layers_kv_cache(
-self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
+self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
 ):
 """
 Initialize with a single KV cache tensor used by all layers.

View File

@@ -36,6 +36,7 @@ from typing import TYPE_CHECKING, Any, Optional
 import torch
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.distributed.kv_transfer.kv_connector.v1 import (
 KVConnectorBase_V1,
 KVConnectorRole,
@@ -45,7 +46,6 @@ from vllm.logger import init_logger
 from vllm.utils.math_utils import cdiv
 if TYPE_CHECKING:
-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.forward_context import ForwardContext
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -117,7 +117,7 @@ class DecodeBenchConnector(KVConnectorBase_V1):
 self,
 layer_name: str,
 kv_layer: torch.Tensor,
-attn_metadata: "AttentionMetadata",
+attn_metadata: AttentionMetadata,
 **kwargs: Any,
 ) -> None:
 # This connector doesn't save KV cache (benchmarking only)

View File

@@ -7,6 +7,7 @@ from lmcache.integration.vllm.vllm_v1_adapter import (
 LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
 )
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
 KVConnectorBase_V1,
@@ -17,7 +18,6 @@ from vllm.logger import init_logger
 from vllm.v1.core.sched.output import SchedulerOutput
 if TYPE_CHECKING:
-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.forward_context import ForwardContext
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -91,7 +91,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
 self,
 layer_name: str,
 kv_layer: torch.Tensor,
-attn_metadata: "AttentionMetadata",
+attn_metadata: AttentionMetadata,
 **kwargs: Any,
 ) -> None:
 """

View File

@@ -29,6 +29,7 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
 from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
 from lmcache.v1.plugin.plugin_launcher import PluginLauncher
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
 KVConnectorBase_V1,
@@ -50,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.version import __version__ as VLLM_VERSION
 if TYPE_CHECKING:
-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.forward_context import ForwardContext
 from vllm.multimodal.inputs import PlaceholderRange
 from vllm.v1.core.kv_cache_manager import KVCacheManager
@@ -915,7 +915,7 @@ class LMCacheConnectorV1Impl:
 self,
 layer_name: str,
 kv_layer: torch.Tensor,
-attn_metadata: "AttentionMetadata",
+attn_metadata: AttentionMetadata,
 **kwargs,
 ) -> None:
 """Start saving the a layer of KV cache from vLLM's paged buffer

View File

@@ -10,6 +10,7 @@ import zmq
 from lmcache.integration.vllm.utils import mla_enabled
 from lmcache.utils import init_logger as lmcache_init_logger
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
 KVConnectorBase_V1,
@@ -26,7 +27,6 @@ from vllm.v1.outputs import KVConnectorOutput
 from vllm.v1.utils import ConstantList
 if TYPE_CHECKING:
-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed.kv_events import KVCacheEvent
 from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
@@ -490,7 +490,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
 self,
 layer_name: str,
 kv_layer: torch.Tensor,
-attn_metadata: "AttentionMetadata",
+attn_metadata: AttentionMetadata,
 **kwargs: Any,
 ) -> None:
 """

View File

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any
 import torch
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.config.kv_transfer import KVTransferConfig
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
@@ -27,7 +28,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.outputs import KVConnectorOutput
 if TYPE_CHECKING:
-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.distributed.kv_events import KVCacheEvent
 from vllm.forward_context import ForwardContext
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -216,7 +216,7 @@ class MultiConnector(KVConnectorBase_V1):
 self,
 layer_name: str,
 kv_layer: torch.Tensor,
-attn_metadata: "AttentionMetadata",
+attn_metadata: AttentionMetadata,
 **kwargs,
 ) -> None:
 for c in self._connectors:

View File

@@ -20,7 +20,7 @@ import torch
 import zmq
 from vllm import envs
-from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.selector import get_attn_backend
 from vllm.config import VllmConfig
@@ -51,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.block_table import BlockTable
 if TYPE_CHECKING:
-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.request import Request
@@ -308,7 +307,7 @@ class NixlConnector(KVConnectorBase_V1):
 self,
 layer_name: str,
 kv_layer: torch.Tensor,
-attn_metadata: "AttentionMetadata",
+attn_metadata: AttentionMetadata,
 **kwargs,
 ) -> None:
 """NixlConnector does not save explicitly."""

View File

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
 import regex as re
 import torch
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
 KVConnectorBase_V1,
@@ -22,7 +23,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 if TYPE_CHECKING:
-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.forward_context import ForwardContext
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -243,7 +243,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
 self,
 layer_name: str,
 kv_layer: torch.Tensor,
-attn_metadata: "AttentionMetadata",
+attn_metadata: AttentionMetadata,
 **kwargs: Any,
 ) -> None:
 """Start saving the KV cache of the layer from vLLM's paged buffer

View File

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
 import safetensors
 import torch
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
 KVConnectorBase_V1,
@@ -19,7 +20,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 if TYPE_CHECKING:
-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.forward_context import ForwardContext
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -211,7 +211,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
 self,
 layer_name: str,
 kv_layer: torch.Tensor,
-attn_metadata: "AttentionMetadata",
+attn_metadata: AttentionMetadata,
 **kwargs: Any,
 ) -> None:
 """Start saving the KV cache of the layer from vLLM's paged buffer

View File

@@ -5,19 +5,17 @@ import time
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, NamedTuple
+from typing import Any, NamedTuple
 import torch
 import vllm.envs as envs
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
 from vllm.v1.worker.ubatch_utils import UBatchSlices
-if TYPE_CHECKING:
-from vllm.attention.backends.abstract import AttentionMetadata
 logger = init_logger(__name__)
 track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
@@ -195,7 +193,7 @@ class ForwardContext:
 for each microbatch.
 Set dynamically for each forward pass
 """
-attn_metadata: dict[str, "AttentionMetadata"] | list[dict[str, "AttentionMetadata"]]
+attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
 # TODO: remove after making all virtual_engines share the same kv cache
 virtual_engine: int # set dynamically for each forward pass
 # set dynamically for each forward pass

View File

@@ -3,14 +3,11 @@
 """Base class for attention-like layers."""
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import VllmConfig
 from vllm.v1.kv_cache_interface import KVCacheSpec
-if TYPE_CHECKING:
-from vllm.attention.backends.abstract import AttentionBackend
 class AttentionLayerBase(ABC):
 """
@@ -22,7 +19,7 @@ class AttentionLayerBase(ABC):
 """
 @abstractmethod
-def get_attn_backend(self) -> type["AttentionBackend"]:
+def get_attn_backend(self) -> type[AttentionBackend]:
 """Get the attention backend class for this layer."""
 pass

View File

@@ -2,18 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import abstractmethod
 from collections.abc import Iterable
-from typing import TYPE_CHECKING
 import torch
+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.selector import get_mamba_attn_backend
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
-if TYPE_CHECKING:
-from vllm.attention.backends.abstract import AttentionBackend
 class MambaBase(AttentionLayerBase):
 """
@@ -66,6 +63,6 @@ class MambaBase(AttentionLayerBase):
 ),
 )
-def get_attn_backend(self) -> type["AttentionBackend"]:
+def get_attn_backend(self) -> type[AttentionBackend]:
 """Get the attention backend class for this Mamba layer."""
 return get_mamba_attn_backend(self.mamba_type)

View File

@@ -18,6 +18,7 @@ from compressed_tensors.quantization import (
 from compressed_tensors.transform import TransformConfig
 import vllm.envs as envs
+from vllm.attention.layer import Attention
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (
@@ -131,8 +132,6 @@ class CompressedTensorsConfig(QuantizationConfig):
 layer: torch.nn.Module,
 prefix: str,
 ) -> Optional["QuantizeMethodBase"]:
-from vllm.attention.layer import Attention # Avoid circular import
 if isinstance(layer, LinearBase):
 # collect schemes
 quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)

View File

@@ -14,6 +14,7 @@ import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm._aiter_ops import rocm_aiter_ops
+from vllm.attention.layer import Attention
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
@@ -277,7 +278,6 @@ class Fp8Config(QuantizationConfig):
 def get_xpu_quant_method(
 self, layer: torch.nn.Module, prefix: str
 ) -> Optional["QuantizeMethodBase"]:
-from vllm.attention.layer import Attention
 from vllm.model_executor.layers.quantization.ipex_quant import (
 XPUFp8LinearMethod,
 XPUFp8MoEMethod,
@@ -307,8 +307,6 @@ class Fp8Config(QuantizationConfig):
 def get_quant_method(
 self, layer: torch.nn.Module, prefix: str
 ) -> Optional["QuantizeMethodBase"]:
-from vllm.attention.layer import Attention # Avoid circular import
 if current_platform.is_xpu():
 return self.get_xpu_quant_method(layer, prefix)
 if isinstance(layer, LinearBase):

View File

@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
+from vllm.attention.layer import Attention
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
 FusedMoEQuantConfig,
@@ -149,8 +150,6 @@ class ModelOptQuantConfigBase(QuantizationConfig):
 def get_quant_method(
 self, layer: torch.nn.Module, prefix: str
 ) -> Optional["QuantizeMethodBase"]:
-from vllm.attention.layer import Attention # Avoid circular import
 # handle kv-cache first so we can focus only on weight quantization thereafter
 if isinstance(layer, Attention):
 return self.KVCacheMethodCls(self)

View File

@@ -8,6 +8,7 @@ import torch
 from torch.nn.parameter import Parameter
 from vllm import envs
+from vllm.attention.layer import Attention
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (
@@ -184,8 +185,6 @@ class Mxfp4Config(QuantizationConfig):
 def get_quant_method(
 self, layer: torch.nn.Module, prefix: str
 ) -> Optional["QuantizeMethodBase"]:
-from vllm.attention.layer import Attention # Avoid circular import
 if isinstance(layer, LinearBase):
 if self.ignored_layers and is_layer_skipped(
 prefix=prefix,

View File

@@ -8,6 +8,7 @@ import regex as re
 import torch
 from torch.nn.parameter import Parameter
+from vllm.attention.layer import Attention
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (
 LinearBase,
@@ -159,8 +160,6 @@ class PetitNvFp4Config(QuantizationConfig):
 def get_quant_method(
 self, layer: torch.nn.Module, prefix: str
 ) -> Optional["QuantizeMethodBase"]:
-from vllm.attention.layer import Attention # Avoid circular import
 exclude = self.require_exclude_modules()
 if isinstance(layer, LinearBase):

View File

@@ -7,6 +7,7 @@ import torch
 from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
+from vllm.attention.layer import Attention
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -65,8 +66,6 @@ class PTPCFp8Config(Fp8Config):
 def get_quant_method(
 self, layer: torch.nn.Module, prefix: str
 ) -> Optional["QuantizeMethodBase"]:
-from vllm.attention.layer import Attention # Avoid circular import
 if isinstance(layer, LinearBase):
 if is_layer_skipped(prefix, self.ignored_layers):
 return UnquantizedLinearMethod()

View File

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
 import torch
+from vllm.attention.layer import Attention
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (
@@ -102,8 +103,6 @@ class QuarkConfig(QuantizationConfig):
 def get_quant_method(
 self, layer: torch.nn.Module, prefix: str
 ) -> Optional["QuantizeMethodBase"]:
-from vllm.attention.layer import Attention # Avoid circular import
 # Check if the layer is skipped for quantization.
 exclude_layers = cast(list[str], self.quant_config.get("exclude"))
 if should_ignore_layer(

View File

@@ -14,6 +14,7 @@ import regex as re
 import torch
 from vllm import envs
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
 from .interface import CpuArchEnum, Platform, PlatformEnum
@@ -21,10 +22,8 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
 logger = init_logger(__name__)
 if TYPE_CHECKING:
-from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import VllmConfig
 else:
-AttentionBackendEnum = None
 VllmConfig = None
@@ -135,8 +134,6 @@ class CpuPlatform(Platform):
 use_sparse: bool,
 attn_type: str | None = None,
 ) -> str:
-from vllm.attention.backends.registry import AttentionBackendEnum
 if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
 logger.info("Cannot use %s backend on CPU.", selected_backend)
 if use_mla:

View File

@@ -15,6 +15,8 @@ from typing_extensions import ParamSpec
 # import custom ops, trigger op registration
 import vllm._C # noqa
 import vllm.envs as envs
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
 from vllm.utils.import_utils import import_pynvml
 from vllm.utils.torch_utils import cuda_device_count_stateless
@@ -22,11 +24,9 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
 from .interface import DeviceCapability, Platform, PlatformEnum
 if TYPE_CHECKING:
-from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import VllmConfig
 from vllm.config.cache import CacheDType
 else:
-AttentionBackendEnum = None
 VllmConfig = None
 CacheDType = None
@@ -48,8 +48,6 @@ def _get_backend_priorities(
 device_capability: DeviceCapability,
 ) -> list[AttentionBackendEnum]:
 """Get backend priorities with lazy import to avoid circular dependency."""
-from vllm.attention.backends.registry import AttentionBackendEnum
 if use_mla:
 if device_capability.major == 10:
 return [
@@ -265,8 +263,6 @@ class CudaPlatformBase(Platform):
 def get_vit_attn_backend(
 cls, head_size: int, dtype: torch.dtype
 ) -> "AttentionBackendEnum":
-from vllm.attention.backends.registry import AttentionBackendEnum
 # Try FlashAttention first
 try:
 backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
@@ -335,8 +331,6 @@ class CudaPlatformBase(Platform):
 use_sparse: bool,
 attn_type: str | None = None,
 ) -> str:
-from vllm.attention.backends.abstract import AttentionType
 if attn_type is None:
 attn_type = AttentionType.DECODER

View File

@@ -12,12 +12,12 @@ from typing import TYPE_CHECKING, Any, NamedTuple
 import numpy as np
 import torch
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
 if TYPE_CHECKING:
 from torch.distributed import PrefixStore, ProcessGroup
-from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import VllmConfig
 from vllm.config.cache import CacheDType
 from vllm.inputs import ProcessorInputs, PromptType
@@ -226,9 +226,6 @@ class Platform:
 def get_vit_attn_backend(
 cls, head_size: int, dtype: torch.dtype
 ) -> "AttentionBackendEnum":
-# Import AttentionBackendEnum here to avoid circular import.
-from vllm.attention.backends.registry import AttentionBackendEnum
 return AttentionBackendEnum.TORCH_SDPA
 @classmethod

View File

@@ -8,16 +8,14 @@ from typing import TYPE_CHECKING
 import torch
 import vllm.envs as envs
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
 from vllm.utils.torch_utils import cuda_device_count_stateless
 from .interface import DeviceCapability, Platform, PlatformEnum
 if TYPE_CHECKING:
-from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import VllmConfig
-else:
-AttentionBackendEnum = None
 logger = init_logger(__name__)
@@ -196,7 +194,6 @@ class RocmPlatform(Platform):
 from importlib.util import find_spec
 from vllm._aiter_ops import rocm_aiter_ops
-from vllm.attention.backends.registry import AttentionBackendEnum
 if rocm_aiter_ops.is_mha_enabled():
 # Note: AITER FA is only supported for Qwen-VL models.
@@ -222,7 +219,6 @@ class RocmPlatform(Platform):
 attn_type: str | None = None,
 ) -> str:
 from vllm._aiter_ops import rocm_aiter_ops
-from vllm.attention.backends.registry import AttentionBackendEnum
 if use_sparse:
 if kv_cache_dtype.startswith("fp8"):

View File

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, cast
 import torch
 from tpu_info import device
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
@@ -15,7 +16,6 @@ from .interface import Platform, PlatformEnum
 if TYPE_CHECKING:
 from typing import TypeAlias
-from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import VllmConfig
 from vllm.config.cache import BlockSize
 from vllm.pooling_params import PoolingParams
@@ -26,7 +26,6 @@ else:
 BlockSize = None
 VllmConfig = None
 PoolingParams = None
-AttentionBackendEnum = None
 ParamsType = None
 logger = init_logger(__name__)
@@ -67,8 +66,6 @@ class TpuPlatform(Platform):
 use_sparse,
 attn_type: str | None = None,
 ) -> str:
-from vllm.attention.backends.registry import AttentionBackendEnum
 if use_sparse:
 raise NotImplementedError("Sparse Attention is not supported on TPU.")
 if selected_backend != AttentionBackendEnum.PALLAS:

View File

@@ -8,16 +8,15 @@ from typing import TYPE_CHECKING
 import torch
 import vllm.envs as envs
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
 from .interface import DeviceCapability, Platform, PlatformEnum
 if TYPE_CHECKING:
-from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import VllmConfig
 else:
 VllmConfig = None
-AttentionBackendEnum = None
 logger = init_logger(__name__)
@@ -60,8 +59,6 @@ class XPUPlatform(Platform):
 "only NHD layout is supported by XPU attention kernels."
 )
-from vllm.attention.backends.registry import AttentionBackendEnum
 if use_sparse:
 raise NotImplementedError("Sparse Attention is not supported on XPU.")
 if selected_backend == AttentionBackendEnum.TRITON_ATTN:
@@ -116,8 +113,6 @@ class XPUPlatform(Platform):
 def get_vit_attn_backend(
 cls, head_size: int, dtype: torch.dtype
 ) -> "AttentionBackendEnum":
-from vllm.attention.backends.registry import AttentionBackendEnum
 return AttentionBackendEnum.FLASH_ATTN
 @classmethod

View File

@@ -51,8 +51,6 @@ class CPUAttentionBackend(AttentionBackend):
 @classmethod
 def supports_attn_type(cls, attn_type: str) -> bool:
 """CPU attention supports decoder and encoder-only attention."""
-from vllm.attention.backends.abstract import AttentionType
 return attn_type in (
 AttentionType.DECODER,
 AttentionType.ENCODER,

View File

@@ -84,8 +84,6 @@ class FlashAttentionBackend(AttentionBackend):
 @classmethod
 def supports_attn_type(cls, attn_type: str) -> bool:
 """FlashAttention supports all attention types."""
-from vllm.attention.backends.abstract import AttentionType
 return attn_type in (
 AttentionType.DECODER,
 AttentionType.ENCODER,

View File

@@ -87,8 +87,6 @@ class FlexAttentionBackend(AttentionBackend):
 @classmethod
 def supports_attn_type(cls, attn_type: str) -> bool:
 """FlexAttention supports both decoder and encoder-only attention."""
-from vllm.attention.backends.abstract import AttentionType
 return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
 @staticmethod

View File

@@ -24,12 +24,15 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.utils.math_utils import cdiv
 if TYPE_CHECKING:
-from vllm.attention.backends.abstract import AttentionImpl
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch
 import vllm.envs as envs
-from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
+from vllm.attention.backends.abstract import (
+AttentionBackend,
+AttentionImpl,
+AttentionMetadata,
+)
 from vllm.distributed.kv_transfer.kv_connector.utils import (
 get_kv_connector_cache_layout,
 )

View File

@@ -6,12 +6,12 @@ from typing import TYPE_CHECKING
 import torch
+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
 from vllm.v1.kv_offload.worker.worker import OffloadingHandler
 if TYPE_CHECKING:
-from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import VllmConfig
 logger = init_logger(__name__)
@@ -51,7 +51,7 @@ class OffloadingSpec(ABC):
 def get_handlers(
 self,
 kv_caches: dict[str, torch.Tensor],
-attn_backends: dict[str, type["AttentionBackend"]],
+attn_backends: dict[str, type[AttentionBackend]],
 ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
 """
 Get offloading handlers along with their respective src and dst types.

View File

@@ -8,6 +8,7 @@ import numpy as np
 import torch
 import torch.nn as nn
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import (
 CompilationMode,
 CUDAGraphMode,
@@ -157,8 +158,6 @@ class EagleProposer:
 )
 # Determine allowed attention backends once during initialization.
-from vllm.attention.backends.registry import AttentionBackendEnum
 self.allowed_attn_types: tuple | None = None
 if current_platform.is_rocm():
 rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]

View File

@@ -2,11 +2,11 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING
 import torch
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.layer import Attention
 from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.utils import extract_layer_index
@@ -17,9 +17,6 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
 from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
 from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
-if TYPE_CHECKING:
-from vllm.attention.layer import Attention
 class MultiModalBudget:
 """Helper class to calculate budget information for multi-modal models."""
@@ -278,7 +275,7 @@ def add_kv_sharing_layers_to_kv_cache_groups(
 def bind_kv_cache(
 kv_caches: dict[str, torch.Tensor],
-forward_context: dict[str, "Attention"],
+forward_context: dict[str, Attention],
 runner_kv_caches: list[torch.Tensor],
 num_attn_module: int = 1,
 ) -> None: