mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-27 11:41:50 +08:00
[Attention] Update attention imports (#29540)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
parent
cd007a53b4
commit
fc1d8be3dc
@ -139,14 +139,13 @@ def test_standard_attention_backend_selection(
|
||||
import importlib
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
|
||||
importlib.reload(envs)
|
||||
|
||||
# Convert string backend to enum if provided
|
||||
backend_enum = None
|
||||
if selected_backend:
|
||||
backend_enum = getattr(_Backend, selected_backend)
|
||||
backend_enum = getattr(AttentionBackendEnum, selected_backend)
|
||||
|
||||
# Get the backend class path
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
@ -253,7 +252,6 @@ def test_mla_backend_selection(
|
||||
import importlib
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
|
||||
importlib.reload(envs)
|
||||
|
||||
@ -269,7 +267,7 @@ def test_mla_backend_selection(
|
||||
# Convert string backend to enum if provided
|
||||
backend_enum = None
|
||||
if selected_backend:
|
||||
backend_enum = getattr(_Backend, selected_backend)
|
||||
backend_enum = getattr(AttentionBackendEnum, selected_backend)
|
||||
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
|
||||
@ -301,7 +299,6 @@ def test_mla_backend_selection(
|
||||
|
||||
def test_aiter_fa_requires_gfx9(mock_vllm_config):
|
||||
"""Test that ROCM_AITER_FA requires gfx9 architecture."""
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
|
||||
# Mock on_gfx9 to return False
|
||||
@ -313,7 +310,7 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config):
|
||||
),
|
||||
):
|
||||
RocmPlatform.get_attn_backend_cls(
|
||||
selected_backend=_Backend.ROCM_AITER_FA,
|
||||
selected_backend=AttentionBackendEnum.ROCM_AITER_FA,
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
kv_cache_dtype="auto",
|
||||
|
||||
@ -14,6 +14,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||
KVConnectorBase_V1,
|
||||
@ -24,7 +25,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from .utils import create_scheduler, create_vllm_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
@ -68,7 +68,7 @@ class OldStyleTestConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
pass
|
||||
@ -119,7 +119,7 @@ class NewStyleTestConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@ -6,11 +6,10 @@ from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
||||
|
||||
@ -178,8 +177,6 @@ class AttentionBackend(ABC):
|
||||
By default, only supports decoder attention.
|
||||
Backends should override this to support other attention types.
|
||||
"""
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
|
||||
return attn_type == AttentionType.DECODER
|
||||
|
||||
@classmethod
|
||||
@ -360,7 +357,7 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||
def fused_output_quant_supported(self, quant_key: "QuantKey"):
|
||||
"""
|
||||
Does this attention implementation support fused output quantization.
|
||||
This is used by the AttnFusionPass to only fuse output quantization
|
||||
@ -412,7 +409,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
|
||||
qk_rope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
kv_b_proj: ColumnParallelLinear,
|
||||
kv_b_proj: "ColumnParallelLinear",
|
||||
indexer: object | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -5,6 +5,7 @@ import functools
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.config.vllm import VllmConfig
|
||||
@ -22,8 +23,6 @@ from vllm.v1.kv_cache_interface import (
|
||||
KVCacheSpec,
|
||||
)
|
||||
|
||||
from ..layer import Attention
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def create_chunked_local_attention_backend(
|
||||
|
||||
@ -14,6 +14,7 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
|
||||
from vllm.config.pooler import PoolerConfig
|
||||
from vllm.config.scheduler import RunnerType
|
||||
@ -53,7 +54,6 @@ if TYPE_CHECKING:
|
||||
|
||||
import vllm.model_executor.layers.quantization as me_quant
|
||||
import vllm.model_executor.models as me_models
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@ -61,7 +61,6 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
PretrainedConfig = Any
|
||||
|
||||
AttentionBackendEnum = Any
|
||||
me_quant = LazyLoader(
|
||||
"model_executor", globals(), "vllm.model_executor.layers.quantization"
|
||||
)
|
||||
|
||||
@ -2,19 +2,15 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
from pydantic import ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
else:
|
||||
AttentionBackendEnum = Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseDummyOptions:
|
||||
@ -170,9 +166,6 @@ class MultiModalConfig:
|
||||
def _validate_mm_encoder_attn_backend(
|
||||
cls, value: str | AttentionBackendEnum | None
|
||||
) -> AttentionBackendEnum | None:
|
||||
# We need to import the real type here (deferred to avoid circular import).
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if isinstance(value, str) and value.upper() == "XFORMERS":
|
||||
raise ValueError(
|
||||
"Attention backend 'XFORMERS' has been removed (See PR #29262 for "
|
||||
|
||||
@ -42,12 +42,12 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
@ -239,7 +239,7 @@ class KVConnectorBase_V1(ABC):
|
||||
return
|
||||
|
||||
def register_cross_layers_kv_cache(
|
||||
self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
|
||||
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
||||
):
|
||||
"""
|
||||
Initialize with a single KV cache tensor used by all layers.
|
||||
|
||||
@ -36,6 +36,7 @@ from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorRole,
|
||||
@ -45,7 +46,6 @@ from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
@ -117,7 +117,7 @@ class DecodeBenchConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# This connector doesn't save KV cache (benchmarking only)
|
||||
|
||||
@ -7,6 +7,7 @@ from lmcache.integration.vllm.vllm_v1_adapter import (
|
||||
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
|
||||
)
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
@ -17,7 +18,6 @@ from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
@ -91,7 +91,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
|
||||
@ -29,6 +29,7 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
|
||||
from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
|
||||
from lmcache.v1.plugin.plugin_launcher import PluginLauncher
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
@ -50,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager
|
||||
@ -915,7 +915,7 @@ class LMCacheConnectorV1Impl:
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Start saving the a layer of KV cache from vLLM's paged buffer
|
||||
|
||||
@ -10,6 +10,7 @@ import zmq
|
||||
from lmcache.integration.vllm.utils import mla_enabled
|
||||
from lmcache.utils import init_logger as lmcache_init_logger
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
@ -26,7 +27,6 @@ from vllm.v1.outputs import KVConnectorOutput
|
||||
from vllm.v1.utils import ConstantList
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
@ -490,7 +490,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
||||
@ -27,7 +28,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
@ -216,7 +216,7 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
for c in self._connectors:
|
||||
|
||||
@ -20,7 +20,7 @@ import torch
|
||||
import zmq
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
@ -51,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
@ -308,7 +307,7 @@ class NixlConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""NixlConnector does not save explicitly."""
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
@ -22,7 +23,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
@ -243,7 +243,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
@ -19,7 +20,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
@ -211,7 +211,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
||||
|
||||
@ -5,19 +5,17 @@ import time
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
|
||||
from vllm.v1.worker.ubatch_utils import UBatchSlices
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
|
||||
@ -195,7 +193,7 @@ class ForwardContext:
|
||||
for each microbatch.
|
||||
Set dynamically for each forward pass
|
||||
"""
|
||||
attn_metadata: dict[str, "AttentionMetadata"] | list[dict[str, "AttentionMetadata"]]
|
||||
attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
|
||||
# TODO: remove after making all virtual_engines share the same kv cache
|
||||
virtual_engine: int # set dynamically for each forward pass
|
||||
# set dynamically for each forward pass
|
||||
|
||||
@ -3,14 +3,11 @@
|
||||
"""Base class for attention-like layers."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
|
||||
class AttentionLayerBase(ABC):
|
||||
"""
|
||||
@ -22,7 +19,7 @@ class AttentionLayerBase(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_attn_backend(self) -> type["AttentionBackend"]:
|
||||
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||
"""Get the attention backend class for this layer."""
|
||||
pass
|
||||
|
||||
|
||||
@ -2,18 +2,15 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.selector import get_mamba_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
|
||||
class MambaBase(AttentionLayerBase):
|
||||
"""
|
||||
@ -66,6 +63,6 @@ class MambaBase(AttentionLayerBase):
|
||||
),
|
||||
)
|
||||
|
||||
def get_attn_backend(self) -> type["AttentionBackend"]:
|
||||
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||
"""Get the attention backend class for this Mamba layer."""
|
||||
return get_mamba_attn_backend(self.mamba_type)
|
||||
|
||||
@ -18,6 +18,7 @@ from compressed_tensors.quantization import (
|
||||
from compressed_tensors.transform import TransformConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (
|
||||
@ -131,8 +132,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
# collect schemes
|
||||
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
|
||||
|
||||
@ -14,6 +14,7 @@ import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
@ -277,7 +278,6 @@ class Fp8Config(QuantizationConfig):
|
||||
def get_xpu_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.model_executor.layers.quantization.ipex_quant import (
|
||||
XPUFp8LinearMethod,
|
||||
XPUFp8MoEMethod,
|
||||
@ -307,8 +307,6 @@ class Fp8Config(QuantizationConfig):
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
if current_platform.is_xpu():
|
||||
return self.get_xpu_quant_method(layer, prefix)
|
||||
if isinstance(layer, LinearBase):
|
||||
|
||||
@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
@ -149,8 +150,6 @@ class ModelOptQuantConfigBase(QuantizationConfig):
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
# handle kv-cache first so we can focus only on weight quantization thereafter
|
||||
if isinstance(layer, Attention):
|
||||
return self.KVCacheMethodCls(self)
|
||||
|
||||
@ -8,6 +8,7 @@ import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
@ -184,8 +185,6 @@ class Mxfp4Config(QuantizationConfig):
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if self.ignored_layers and is_layer_skipped(
|
||||
prefix=prefix,
|
||||
|
||||
@ -8,6 +8,7 @@ import regex as re
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
@ -159,8 +160,6 @@ class PetitNvFp4Config(QuantizationConfig):
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
exclude = self.require_exclude_modules()
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
|
||||
@ -7,6 +7,7 @@ import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@ -65,8 +66,6 @@ class PTPCFp8Config(Fp8Config):
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix, self.ignored_layers):
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (
|
||||
@ -102,8 +103,6 @@ class QuarkConfig(QuantizationConfig):
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
# Check if the layer is skipped for quantization.
|
||||
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
|
||||
if should_ignore_layer(
|
||||
|
||||
@ -14,6 +14,7 @@ import regex as re
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import CpuArchEnum, Platform, PlatformEnum
|
||||
@ -21,10 +22,8 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
AttentionBackendEnum = None
|
||||
VllmConfig = None
|
||||
|
||||
|
||||
@ -135,8 +134,6 @@ class CpuPlatform(Platform):
|
||||
use_sparse: bool,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
|
||||
logger.info("Cannot use %s backend on CPU.", selected_backend)
|
||||
if use_mla:
|
||||
|
||||
@ -15,6 +15,8 @@ from typing_extensions import ParamSpec
|
||||
# import custom ops, trigger op registration
|
||||
import vllm._C # noqa
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import import_pynvml
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
@ -22,11 +24,9 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
else:
|
||||
AttentionBackendEnum = None
|
||||
VllmConfig = None
|
||||
CacheDType = None
|
||||
|
||||
@ -48,8 +48,6 @@ def _get_backend_priorities(
|
||||
device_capability: DeviceCapability,
|
||||
) -> list[AttentionBackendEnum]:
|
||||
"""Get backend priorities with lazy import to avoid circular dependency."""
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if use_mla:
|
||||
if device_capability.major == 10:
|
||||
return [
|
||||
@ -265,8 +263,6 @@ class CudaPlatformBase(Platform):
|
||||
def get_vit_attn_backend(
|
||||
cls, head_size: int, dtype: torch.dtype
|
||||
) -> "AttentionBackendEnum":
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
# Try FlashAttention first
|
||||
try:
|
||||
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
|
||||
@ -335,8 +331,6 @@ class CudaPlatformBase(Platform):
|
||||
use_sparse: bool,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
|
||||
if attn_type is None:
|
||||
attn_type = AttentionType.DECODER
|
||||
|
||||
|
||||
@ -12,12 +12,12 @@ from typing import TYPE_CHECKING, Any, NamedTuple
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.distributed import PrefixStore, ProcessGroup
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
@ -226,9 +226,6 @@ class Platform:
|
||||
def get_vit_attn_backend(
|
||||
cls, head_size: int, dtype: torch.dtype
|
||||
) -> "AttentionBackendEnum":
|
||||
# Import AttentionBackendEnum here to avoid circular import.
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
return AttentionBackendEnum.TORCH_SDPA
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -8,16 +8,14 @@ from typing import TYPE_CHECKING
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
AttentionBackendEnum = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -196,7 +194,6 @@ class RocmPlatform(Platform):
|
||||
from importlib.util import find_spec
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if rocm_aiter_ops.is_mha_enabled():
|
||||
# Note: AITER FA is only supported for Qwen-VL models.
|
||||
@ -222,7 +219,6 @@ class RocmPlatform(Platform):
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if use_sparse:
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, cast
|
||||
import torch
|
||||
from tpu_info import device
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.logger import init_logger
|
||||
|
||||
@ -15,7 +16,6 @@ from .interface import Platform, PlatformEnum
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeAlias
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import BlockSize
|
||||
from vllm.pooling_params import PoolingParams
|
||||
@ -26,7 +26,6 @@ else:
|
||||
BlockSize = None
|
||||
VllmConfig = None
|
||||
PoolingParams = None
|
||||
AttentionBackendEnum = None
|
||||
ParamsType = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -67,8 +66,6 @@ class TpuPlatform(Platform):
|
||||
use_sparse,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if use_sparse:
|
||||
raise NotImplementedError("Sparse Attention is not supported on TPU.")
|
||||
if selected_backend != AttentionBackendEnum.PALLAS:
|
||||
|
||||
@ -8,16 +8,15 @@ from typing import TYPE_CHECKING
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
AttentionBackendEnum = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -60,8 +59,6 @@ class XPUPlatform(Platform):
|
||||
"only NHD layout is supported by XPU attention kernels."
|
||||
)
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if use_sparse:
|
||||
raise NotImplementedError("Sparse Attention is not supported on XPU.")
|
||||
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
|
||||
@ -116,8 +113,6 @@ class XPUPlatform(Platform):
|
||||
def get_vit_attn_backend(
|
||||
cls, head_size: int, dtype: torch.dtype
|
||||
) -> "AttentionBackendEnum":
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
return AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -51,8 +51,6 @@ class CPUAttentionBackend(AttentionBackend):
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""CPU attention supports decoder and encoder-only attention."""
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
|
||||
return attn_type in (
|
||||
AttentionType.DECODER,
|
||||
AttentionType.ENCODER,
|
||||
|
||||
@ -84,8 +84,6 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""FlashAttention supports all attention types."""
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
|
||||
return attn_type in (
|
||||
AttentionType.DECODER,
|
||||
AttentionType.ENCODER,
|
||||
|
||||
@ -87,8 +87,6 @@ class FlexAttentionBackend(AttentionBackend):
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""FlexAttention supports both decoder and encoder-only attention."""
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
|
||||
return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -24,12 +24,15 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionImpl
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionMetadata,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
get_kv_connector_cache_layout,
|
||||
)
|
||||
|
||||
@ -6,12 +6,12 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
|
||||
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -51,7 +51,7 @@ class OffloadingSpec(ABC):
|
||||
def get_handlers(
|
||||
self,
|
||||
kv_caches: dict[str, torch.Tensor],
|
||||
attn_backends: dict[str, type["AttentionBackend"]],
|
||||
attn_backends: dict[str, type[AttentionBackend]],
|
||||
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
|
||||
"""
|
||||
Get offloading handlers along with their respective src and dst types.
|
||||
|
||||
@ -8,6 +8,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import (
|
||||
CompilationMode,
|
||||
CUDAGraphMode,
|
||||
@ -157,8 +158,6 @@ class EagleProposer:
|
||||
)
|
||||
|
||||
# Determine allowed attention backends once during initialization.
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
self.allowed_attn_types: tuple | None = None
|
||||
if current_platform.is_rocm():
|
||||
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
|
||||
|
||||
@ -2,11 +2,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
@ -17,9 +17,6 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.layer import Attention
|
||||
|
||||
|
||||
class MultiModalBudget:
|
||||
"""Helper class to calculate budget information for multi-modal models."""
|
||||
@ -278,7 +275,7 @@ def add_kv_sharing_layers_to_kv_cache_groups(
|
||||
|
||||
def bind_kv_cache(
|
||||
kv_caches: dict[str, torch.Tensor],
|
||||
forward_context: dict[str, "Attention"],
|
||||
forward_context: dict[str, Attention],
|
||||
runner_kv_caches: list[torch.Tensor],
|
||||
num_attn_module: int = 1,
|
||||
) -> None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user