mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-16 04:02:18 +08:00
[Attention] Update attention imports (#29540)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
parent
cd007a53b4
commit
fc1d8be3dc
@ -139,14 +139,13 @@ def test_standard_attention_backend_selection(
|
|||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.attention.backends.registry import _Backend
|
|
||||||
|
|
||||||
importlib.reload(envs)
|
importlib.reload(envs)
|
||||||
|
|
||||||
# Convert string backend to enum if provided
|
# Convert string backend to enum if provided
|
||||||
backend_enum = None
|
backend_enum = None
|
||||||
if selected_backend:
|
if selected_backend:
|
||||||
backend_enum = getattr(_Backend, selected_backend)
|
backend_enum = getattr(AttentionBackendEnum, selected_backend)
|
||||||
|
|
||||||
# Get the backend class path
|
# Get the backend class path
|
||||||
from vllm.platforms.rocm import RocmPlatform
|
from vllm.platforms.rocm import RocmPlatform
|
||||||
@ -253,7 +252,6 @@ def test_mla_backend_selection(
|
|||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.attention.backends.registry import _Backend
|
|
||||||
|
|
||||||
importlib.reload(envs)
|
importlib.reload(envs)
|
||||||
|
|
||||||
@ -269,7 +267,7 @@ def test_mla_backend_selection(
|
|||||||
# Convert string backend to enum if provided
|
# Convert string backend to enum if provided
|
||||||
backend_enum = None
|
backend_enum = None
|
||||||
if selected_backend:
|
if selected_backend:
|
||||||
backend_enum = getattr(_Backend, selected_backend)
|
backend_enum = getattr(AttentionBackendEnum, selected_backend)
|
||||||
|
|
||||||
from vllm.platforms.rocm import RocmPlatform
|
from vllm.platforms.rocm import RocmPlatform
|
||||||
|
|
||||||
@ -301,7 +299,6 @@ def test_mla_backend_selection(
|
|||||||
|
|
||||||
def test_aiter_fa_requires_gfx9(mock_vllm_config):
|
def test_aiter_fa_requires_gfx9(mock_vllm_config):
|
||||||
"""Test that ROCM_AITER_FA requires gfx9 architecture."""
|
"""Test that ROCM_AITER_FA requires gfx9 architecture."""
|
||||||
from vllm.attention.backends.registry import _Backend
|
|
||||||
from vllm.platforms.rocm import RocmPlatform
|
from vllm.platforms.rocm import RocmPlatform
|
||||||
|
|
||||||
# Mock on_gfx9 to return False
|
# Mock on_gfx9 to return False
|
||||||
@ -313,7 +310,7 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config):
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
RocmPlatform.get_attn_backend_cls(
|
RocmPlatform.get_attn_backend_cls(
|
||||||
selected_backend=_Backend.ROCM_AITER_FA,
|
selected_backend=AttentionBackendEnum.ROCM_AITER_FA,
|
||||||
head_size=128,
|
head_size=128,
|
||||||
dtype=torch.float16,
|
dtype=torch.float16,
|
||||||
kv_cache_dtype="auto",
|
kv_cache_dtype="auto",
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||||
KVConnectorBase_V1,
|
KVConnectorBase_V1,
|
||||||
@ -24,7 +25,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
|||||||
from .utils import create_scheduler, create_vllm_config
|
from .utils import create_scheduler, create_vllm_config
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
@ -68,7 +68,7 @@ class OldStyleTestConnector(KVConnectorBase_V1):
|
|||||||
self,
|
self,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
kv_layer,
|
kv_layer,
|
||||||
attn_metadata: "AttentionMetadata",
|
attn_metadata: AttentionMetadata,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
@ -119,7 +119,7 @@ class NewStyleTestConnector(KVConnectorBase_V1):
|
|||||||
self,
|
self,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
kv_layer,
|
kv_layer,
|
||||||
attn_metadata: "AttentionMetadata",
|
attn_metadata: AttentionMetadata,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -6,11 +6,10 @@ from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config.cache import CacheDType
|
from vllm.config.cache import CacheDType
|
||||||
|
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||||
from vllm.platforms.interface import DeviceCapability
|
from vllm.platforms.interface import DeviceCapability
|
||||||
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
||||||
|
|
||||||
@ -178,8 +177,6 @@ class AttentionBackend(ABC):
|
|||||||
By default, only supports decoder attention.
|
By default, only supports decoder attention.
|
||||||
Backends should override this to support other attention types.
|
Backends should override this to support other attention types.
|
||||||
"""
|
"""
|
||||||
from vllm.attention.backends.abstract import AttentionType
|
|
||||||
|
|
||||||
return attn_type == AttentionType.DECODER
|
return attn_type == AttentionType.DECODER
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -360,7 +357,7 @@ class AttentionImpl(ABC, Generic[T]):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
def fused_output_quant_supported(self, quant_key: "QuantKey"):
|
||||||
"""
|
"""
|
||||||
Does this attention implementation support fused output quantization.
|
Does this attention implementation support fused output quantization.
|
||||||
This is used by the AttnFusionPass to only fuse output quantization
|
This is used by the AttnFusionPass to only fuse output quantization
|
||||||
@ -412,7 +409,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
|
|||||||
qk_rope_head_dim: int,
|
qk_rope_head_dim: int,
|
||||||
qk_head_dim: int,
|
qk_head_dim: int,
|
||||||
v_head_dim: int,
|
v_head_dim: int,
|
||||||
kv_b_proj: ColumnParallelLinear,
|
kv_b_proj: "ColumnParallelLinear",
|
||||||
indexer: object | None = None,
|
indexer: object | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import functools
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
from vllm.attention.selector import get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.config.vllm import VllmConfig
|
from vllm.config.vllm import VllmConfig
|
||||||
@ -22,8 +23,6 @@ from vllm.v1.kv_cache_interface import (
|
|||||||
KVCacheSpec,
|
KVCacheSpec,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..layer import Attention
|
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def create_chunked_local_attention_backend(
|
def create_chunked_local_attention_backend(
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
|||||||
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
|
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
|
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
|
||||||
from vllm.config.pooler import PoolerConfig
|
from vllm.config.pooler import PoolerConfig
|
||||||
from vllm.config.scheduler import RunnerType
|
from vllm.config.scheduler import RunnerType
|
||||||
@ -53,7 +54,6 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
import vllm.model_executor.layers.quantization as me_quant
|
import vllm.model_executor.layers.quantization as me_quant
|
||||||
import vllm.model_executor.models as me_models
|
import vllm.model_executor.models as me_models
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
from vllm.config.load import LoadConfig
|
from vllm.config.load import LoadConfig
|
||||||
from vllm.config.parallel import ParallelConfig
|
from vllm.config.parallel import ParallelConfig
|
||||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
@ -61,7 +61,6 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
PretrainedConfig = Any
|
PretrainedConfig = Any
|
||||||
|
|
||||||
AttentionBackendEnum = Any
|
|
||||||
me_quant = LazyLoader(
|
me_quant = LazyLoader(
|
||||||
"model_executor", globals(), "vllm.model_executor.layers.quantization"
|
"model_executor", globals(), "vllm.model_executor.layers.quantization"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -2,19 +2,15 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
|
from typing import Any, Literal, TypeAlias
|
||||||
|
|
||||||
from pydantic import ConfigDict, Field, field_validator, model_validator
|
from pydantic import ConfigDict, Field, field_validator, model_validator
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.config.utils import config
|
from vllm.config.utils import config
|
||||||
from vllm.utils.hashing import safe_hash
|
from vllm.utils.hashing import safe_hash
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
else:
|
|
||||||
AttentionBackendEnum = Any
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseDummyOptions:
|
class BaseDummyOptions:
|
||||||
@ -170,9 +166,6 @@ class MultiModalConfig:
|
|||||||
def _validate_mm_encoder_attn_backend(
|
def _validate_mm_encoder_attn_backend(
|
||||||
cls, value: str | AttentionBackendEnum | None
|
cls, value: str | AttentionBackendEnum | None
|
||||||
) -> AttentionBackendEnum | None:
|
) -> AttentionBackendEnum | None:
|
||||||
# We need to import the real type here (deferred to avoid circular import).
|
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
|
|
||||||
if isinstance(value, str) and value.upper() == "XFORMERS":
|
if isinstance(value, str) and value.upper() == "XFORMERS":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Attention backend 'XFORMERS' has been removed (See PR #29262 for "
|
"Attention backend 'XFORMERS' has been removed (See PR #29262 for "
|
||||||
|
|||||||
@ -42,12 +42,12 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.outputs import KVConnectorOutput
|
from vllm.v1.outputs import KVConnectorOutput
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_events import KVCacheEvent
|
from vllm.distributed.kv_events import KVCacheEvent
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||||
@ -239,7 +239,7 @@ class KVConnectorBase_V1(ABC):
|
|||||||
return
|
return
|
||||||
|
|
||||||
def register_cross_layers_kv_cache(
|
def register_cross_layers_kv_cache(
|
||||||
self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
|
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize with a single KV cache tensor used by all layers.
|
Initialize with a single KV cache tensor used by all layers.
|
||||||
|
|||||||
@ -36,6 +36,7 @@ from typing import TYPE_CHECKING, Any, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||||
KVConnectorBase_V1,
|
KVConnectorBase_V1,
|
||||||
KVConnectorRole,
|
KVConnectorRole,
|
||||||
@ -45,7 +46,6 @@ from vllm.logger import init_logger
|
|||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
@ -117,7 +117,7 @@ class DecodeBenchConnector(KVConnectorBase_V1):
|
|||||||
self,
|
self,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
kv_layer: torch.Tensor,
|
kv_layer: torch.Tensor,
|
||||||
attn_metadata: "AttentionMetadata",
|
attn_metadata: AttentionMetadata,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
# This connector doesn't save KV cache (benchmarking only)
|
# This connector doesn't save KV cache (benchmarking only)
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from lmcache.integration.vllm.vllm_v1_adapter import (
|
|||||||
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
|
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
KVConnectorBase_V1,
|
KVConnectorBase_V1,
|
||||||
@ -17,7 +18,6 @@ from vllm.logger import init_logger
|
|||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
@ -91,7 +91,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
|
|||||||
self,
|
self,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
kv_layer: torch.Tensor,
|
kv_layer: torch.Tensor,
|
||||||
attn_metadata: "AttentionMetadata",
|
attn_metadata: AttentionMetadata,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -29,6 +29,7 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
|
|||||||
from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
|
from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
|
||||||
from lmcache.v1.plugin.plugin_launcher import PluginLauncher
|
from lmcache.v1.plugin.plugin_launcher import PluginLauncher
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
KVConnectorBase_V1,
|
KVConnectorBase_V1,
|
||||||
@ -50,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
|||||||
from vllm.version import __version__ as VLLM_VERSION
|
from vllm.version import __version__ as VLLM_VERSION
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.multimodal.inputs import PlaceholderRange
|
from vllm.multimodal.inputs import PlaceholderRange
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheManager
|
from vllm.v1.core.kv_cache_manager import KVCacheManager
|
||||||
@ -915,7 +915,7 @@ class LMCacheConnectorV1Impl:
|
|||||||
self,
|
self,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
kv_layer: torch.Tensor,
|
kv_layer: torch.Tensor,
|
||||||
attn_metadata: "AttentionMetadata",
|
attn_metadata: AttentionMetadata,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start saving the a layer of KV cache from vLLM's paged buffer
|
"""Start saving the a layer of KV cache from vLLM's paged buffer
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import zmq
|
|||||||
from lmcache.integration.vllm.utils import mla_enabled
|
from lmcache.integration.vllm.utils import mla_enabled
|
||||||
from lmcache.utils import init_logger as lmcache_init_logger
|
from lmcache.utils import init_logger as lmcache_init_logger
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
KVConnectorBase_V1,
|
KVConnectorBase_V1,
|
||||||
@ -26,7 +27,6 @@ from vllm.v1.outputs import KVConnectorOutput
|
|||||||
from vllm.v1.utils import ConstantList
|
from vllm.v1.utils import ConstantList
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_events import KVCacheEvent
|
from vllm.distributed.kv_events import KVCacheEvent
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||||
@ -490,7 +490,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
|
|||||||
self,
|
self,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
kv_layer: torch.Tensor,
|
kv_layer: torch.Tensor,
|
||||||
attn_metadata: "AttentionMetadata",
|
attn_metadata: AttentionMetadata,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.kv_transfer import KVTransferConfig
|
from vllm.config.kv_transfer import KVTransferConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
||||||
@ -27,7 +28,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
|||||||
from vllm.v1.outputs import KVConnectorOutput
|
from vllm.v1.outputs import KVConnectorOutput
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
|
||||||
from vllm.distributed.kv_events import KVCacheEvent
|
from vllm.distributed.kv_events import KVCacheEvent
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
@ -216,7 +216,7 @@ class MultiConnector(KVConnectorBase_V1):
|
|||||||
self,
|
self,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
kv_layer: torch.Tensor,
|
kv_layer: torch.Tensor,
|
||||||
attn_metadata: "AttentionMetadata",
|
attn_metadata: AttentionMetadata,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
for c in self._connectors:
|
for c in self._connectors:
|
||||||
|
|||||||
@ -20,7 +20,7 @@ import torch
|
|||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.attention.selector import get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
@ -51,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
|||||||
from vllm.v1.worker.block_table import BlockTable
|
from vllm.v1.worker.block_table import BlockTable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
@ -308,7 +307,7 @@ class NixlConnector(KVConnectorBase_V1):
|
|||||||
self,
|
self,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
kv_layer: torch.Tensor,
|
kv_layer: torch.Tensor,
|
||||||
attn_metadata: "AttentionMetadata",
|
attn_metadata: AttentionMetadata,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""NixlConnector does not save explicitly."""
|
"""NixlConnector does not save explicitly."""
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
|
|||||||
import regex as re
|
import regex as re
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
KVConnectorBase_V1,
|
KVConnectorBase_V1,
|
||||||
@ -22,7 +23,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
|
|||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
@ -243,7 +243,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
|
|||||||
self,
|
self,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
kv_layer: torch.Tensor,
|
kv_layer: torch.Tensor,
|
||||||
attn_metadata: "AttentionMetadata",
|
attn_metadata: AttentionMetadata,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
|
|||||||
import safetensors
|
import safetensors
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
KVConnectorBase_V1,
|
KVConnectorBase_V1,
|
||||||
@ -19,7 +20,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
|
|||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
@ -211,7 +211,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
|
|||||||
self,
|
self,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
kv_layer: torch.Tensor,
|
kv_layer: torch.Tensor,
|
||||||
attn_metadata: "AttentionMetadata",
|
attn_metadata: AttentionMetadata,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
||||||
|
|||||||
@ -5,19 +5,17 @@ import time
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, NamedTuple
|
from typing import Any, NamedTuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
|
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
|
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
|
||||||
from vllm.v1.worker.ubatch_utils import UBatchSlices
|
from vllm.v1.worker.ubatch_utils import UBatchSlices
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
|
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
|
||||||
@ -195,7 +193,7 @@ class ForwardContext:
|
|||||||
for each microbatch.
|
for each microbatch.
|
||||||
Set dynamically for each forward pass
|
Set dynamically for each forward pass
|
||||||
"""
|
"""
|
||||||
attn_metadata: dict[str, "AttentionMetadata"] | list[dict[str, "AttentionMetadata"]]
|
attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
|
||||||
# TODO: remove after making all virtual_engines share the same kv cache
|
# TODO: remove after making all virtual_engines share the same kv cache
|
||||||
virtual_engine: int # set dynamically for each forward pass
|
virtual_engine: int # set dynamically for each forward pass
|
||||||
# set dynamically for each forward pass
|
# set dynamically for each forward pass
|
||||||
|
|||||||
@ -3,14 +3,11 @@
|
|||||||
"""Base class for attention-like layers."""
|
"""Base class for attention-like layers."""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
|
||||||
|
|
||||||
|
|
||||||
class AttentionLayerBase(ABC):
|
class AttentionLayerBase(ABC):
|
||||||
"""
|
"""
|
||||||
@ -22,7 +19,7 @@ class AttentionLayerBase(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_attn_backend(self) -> type["AttentionBackend"]:
|
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||||
"""Get the attention backend class for this layer."""
|
"""Get the attention backend class for this layer."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@ -2,18 +2,15 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
from vllm.attention.selector import get_mamba_attn_backend
|
from vllm.attention.selector import get_mamba_attn_backend
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||||
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
|
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
|
||||||
|
|
||||||
|
|
||||||
class MambaBase(AttentionLayerBase):
|
class MambaBase(AttentionLayerBase):
|
||||||
"""
|
"""
|
||||||
@ -66,6 +63,6 @@ class MambaBase(AttentionLayerBase):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_attn_backend(self) -> type["AttentionBackend"]:
|
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||||
"""Get the attention backend class for this Mamba layer."""
|
"""Get the attention backend class for this Mamba layer."""
|
||||||
return get_mamba_attn_backend(self.mamba_type)
|
return get_mamba_attn_backend(self.mamba_type)
|
||||||
|
|||||||
@ -18,6 +18,7 @@ from compressed_tensors.quantization import (
|
|||||||
from compressed_tensors.transform import TransformConfig
|
from compressed_tensors.transform import TransformConfig
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
@ -131,8 +132,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
prefix: str,
|
prefix: str,
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
from vllm.attention.layer import Attention # Avoid circular import
|
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
# collect schemes
|
# collect schemes
|
||||||
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
|
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
|
||||||
|
|||||||
@ -14,6 +14,7 @@ import vllm.envs as envs
|
|||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm._aiter_ops import rocm_aiter_ops
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.batch_invariant import (
|
from vllm.model_executor.layers.batch_invariant import (
|
||||||
@ -277,7 +278,6 @@ class Fp8Config(QuantizationConfig):
|
|||||||
def get_xpu_quant_method(
|
def get_xpu_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
from vllm.attention.layer import Attention
|
|
||||||
from vllm.model_executor.layers.quantization.ipex_quant import (
|
from vllm.model_executor.layers.quantization.ipex_quant import (
|
||||||
XPUFp8LinearMethod,
|
XPUFp8LinearMethod,
|
||||||
XPUFp8MoEMethod,
|
XPUFp8MoEMethod,
|
||||||
@ -307,8 +307,6 @@ class Fp8Config(QuantizationConfig):
|
|||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
from vllm.attention.layer import Attention # Avoid circular import
|
|
||||||
|
|
||||||
if current_platform.is_xpu():
|
if current_platform.is_xpu():
|
||||||
return self.get_xpu_quant_method(layer, prefix)
|
return self.get_xpu_quant_method(layer, prefix)
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe.config import (
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
FusedMoEQuantConfig,
|
FusedMoEQuantConfig,
|
||||||
@ -149,8 +150,6 @@ class ModelOptQuantConfigBase(QuantizationConfig):
|
|||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
from vllm.attention.layer import Attention # Avoid circular import
|
|
||||||
|
|
||||||
# handle kv-cache first so we can focus only on weight quantization thereafter
|
# handle kv-cache first so we can focus only on weight quantization thereafter
|
||||||
if isinstance(layer, Attention):
|
if isinstance(layer, Attention):
|
||||||
return self.KVCacheMethodCls(self)
|
return self.KVCacheMethodCls(self)
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import torch
|
|||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe import (
|
from vllm.model_executor.layers.fused_moe import (
|
||||||
@ -184,8 +185,6 @@ class Mxfp4Config(QuantizationConfig):
|
|||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
from vllm.attention.layer import Attention # Avoid circular import
|
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
if self.ignored_layers and is_layer_skipped(
|
if self.ignored_layers and is_layer_skipped(
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import regex as re
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
LinearBase,
|
LinearBase,
|
||||||
@ -159,8 +160,6 @@ class PetitNvFp4Config(QuantizationConfig):
|
|||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
from vllm.attention.layer import Attention # Avoid circular import
|
|
||||||
|
|
||||||
exclude = self.require_exclude_modules()
|
exclude = self.require_exclude_modules()
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import torch
|
|||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
|
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
@ -65,8 +66,6 @@ class PTPCFp8Config(Fp8Config):
|
|||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
from vllm.attention.layer import Attention # Avoid circular import
|
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
if is_layer_skipped(prefix, self.ignored_layers):
|
if is_layer_skipped(prefix, self.ignored_layers):
|
||||||
return UnquantizedLinearMethod()
|
return UnquantizedLinearMethod()
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
@ -102,8 +103,6 @@ class QuarkConfig(QuantizationConfig):
|
|||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
from vllm.attention.layer import Attention # Avoid circular import
|
|
||||||
|
|
||||||
# Check if the layer is skipped for quantization.
|
# Check if the layer is skipped for quantization.
|
||||||
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
|
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
|
||||||
if should_ignore_layer(
|
if should_ignore_layer(
|
||||||
|
|||||||
@ -14,6 +14,7 @@ import regex as re
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
from .interface import CpuArchEnum, Platform, PlatformEnum
|
from .interface import CpuArchEnum, Platform, PlatformEnum
|
||||||
@ -21,10 +22,8 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
else:
|
else:
|
||||||
AttentionBackendEnum = None
|
|
||||||
VllmConfig = None
|
VllmConfig = None
|
||||||
|
|
||||||
|
|
||||||
@ -135,8 +134,6 @@ class CpuPlatform(Platform):
|
|||||||
use_sparse: bool,
|
use_sparse: bool,
|
||||||
attn_type: str | None = None,
|
attn_type: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
|
|
||||||
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
|
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
|
||||||
logger.info("Cannot use %s backend on CPU.", selected_backend)
|
logger.info("Cannot use %s backend on CPU.", selected_backend)
|
||||||
if use_mla:
|
if use_mla:
|
||||||
|
|||||||
@ -15,6 +15,8 @@ from typing_extensions import ParamSpec
|
|||||||
# import custom ops, trigger op registration
|
# import custom ops, trigger op registration
|
||||||
import vllm._C # noqa
|
import vllm._C # noqa
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
from vllm.attention.backends.abstract import AttentionType
|
||||||
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils.import_utils import import_pynvml
|
from vllm.utils.import_utils import import_pynvml
|
||||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||||
@ -22,11 +24,9 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
|
|||||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.cache import CacheDType
|
from vllm.config.cache import CacheDType
|
||||||
else:
|
else:
|
||||||
AttentionBackendEnum = None
|
|
||||||
VllmConfig = None
|
VllmConfig = None
|
||||||
CacheDType = None
|
CacheDType = None
|
||||||
|
|
||||||
@ -48,8 +48,6 @@ def _get_backend_priorities(
|
|||||||
device_capability: DeviceCapability,
|
device_capability: DeviceCapability,
|
||||||
) -> list[AttentionBackendEnum]:
|
) -> list[AttentionBackendEnum]:
|
||||||
"""Get backend priorities with lazy import to avoid circular dependency."""
|
"""Get backend priorities with lazy import to avoid circular dependency."""
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
|
|
||||||
if use_mla:
|
if use_mla:
|
||||||
if device_capability.major == 10:
|
if device_capability.major == 10:
|
||||||
return [
|
return [
|
||||||
@ -265,8 +263,6 @@ class CudaPlatformBase(Platform):
|
|||||||
def get_vit_attn_backend(
|
def get_vit_attn_backend(
|
||||||
cls, head_size: int, dtype: torch.dtype
|
cls, head_size: int, dtype: torch.dtype
|
||||||
) -> "AttentionBackendEnum":
|
) -> "AttentionBackendEnum":
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
|
|
||||||
# Try FlashAttention first
|
# Try FlashAttention first
|
||||||
try:
|
try:
|
||||||
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
|
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
|
||||||
@ -335,8 +331,6 @@ class CudaPlatformBase(Platform):
|
|||||||
use_sparse: bool,
|
use_sparse: bool,
|
||||||
attn_type: str | None = None,
|
attn_type: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
from vllm.attention.backends.abstract import AttentionType
|
|
||||||
|
|
||||||
if attn_type is None:
|
if attn_type is None:
|
||||||
attn_type = AttentionType.DECODER
|
attn_type = AttentionType.DECODER
|
||||||
|
|
||||||
|
|||||||
@ -12,12 +12,12 @@ from typing import TYPE_CHECKING, Any, NamedTuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from torch.distributed import PrefixStore, ProcessGroup
|
from torch.distributed import PrefixStore, ProcessGroup
|
||||||
|
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.cache import CacheDType
|
from vllm.config.cache import CacheDType
|
||||||
from vllm.inputs import ProcessorInputs, PromptType
|
from vllm.inputs import ProcessorInputs, PromptType
|
||||||
@ -226,9 +226,6 @@ class Platform:
|
|||||||
def get_vit_attn_backend(
|
def get_vit_attn_backend(
|
||||||
cls, head_size: int, dtype: torch.dtype
|
cls, head_size: int, dtype: torch.dtype
|
||||||
) -> "AttentionBackendEnum":
|
) -> "AttentionBackendEnum":
|
||||||
# Import AttentionBackendEnum here to avoid circular import.
|
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
|
|
||||||
return AttentionBackendEnum.TORCH_SDPA
|
return AttentionBackendEnum.TORCH_SDPA
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -8,16 +8,14 @@ from typing import TYPE_CHECKING
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||||
|
|
||||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
else:
|
|
||||||
AttentionBackendEnum = None
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -196,7 +194,6 @@ class RocmPlatform(Platform):
|
|||||||
from importlib.util import find_spec
|
from importlib.util import find_spec
|
||||||
|
|
||||||
from vllm._aiter_ops import rocm_aiter_ops
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
|
|
||||||
if rocm_aiter_ops.is_mha_enabled():
|
if rocm_aiter_ops.is_mha_enabled():
|
||||||
# Note: AITER FA is only supported for Qwen-VL models.
|
# Note: AITER FA is only supported for Qwen-VL models.
|
||||||
@ -222,7 +219,6 @@ class RocmPlatform(Platform):
|
|||||||
attn_type: str | None = None,
|
attn_type: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
from vllm._aiter_ops import rocm_aiter_ops
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
|
|
||||||
if use_sparse:
|
if use_sparse:
|
||||||
if kv_cache_dtype.startswith("fp8"):
|
if kv_cache_dtype.startswith("fp8"):
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, cast
|
|||||||
import torch
|
import torch
|
||||||
from tpu_info import device
|
from tpu_info import device
|
||||||
|
|
||||||
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.inputs import ProcessorInputs, PromptType
|
from vllm.inputs import ProcessorInputs, PromptType
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
@ -15,7 +16,6 @@ from .interface import Platform, PlatformEnum
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import TypeAlias
|
from typing import TypeAlias
|
||||||
|
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.cache import BlockSize
|
from vllm.config.cache import BlockSize
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
@ -26,7 +26,6 @@ else:
|
|||||||
BlockSize = None
|
BlockSize = None
|
||||||
VllmConfig = None
|
VllmConfig = None
|
||||||
PoolingParams = None
|
PoolingParams = None
|
||||||
AttentionBackendEnum = None
|
|
||||||
ParamsType = None
|
ParamsType = None
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -67,8 +66,6 @@ class TpuPlatform(Platform):
|
|||||||
use_sparse,
|
use_sparse,
|
||||||
attn_type: str | None = None,
|
attn_type: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
|
|
||||||
if use_sparse:
|
if use_sparse:
|
||||||
raise NotImplementedError("Sparse Attention is not supported on TPU.")
|
raise NotImplementedError("Sparse Attention is not supported on TPU.")
|
||||||
if selected_backend != AttentionBackendEnum.PALLAS:
|
if selected_backend != AttentionBackendEnum.PALLAS:
|
||||||
|
|||||||
@ -8,16 +8,15 @@ from typing import TYPE_CHECKING
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
else:
|
else:
|
||||||
VllmConfig = None
|
VllmConfig = None
|
||||||
AttentionBackendEnum = None
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -60,8 +59,6 @@ class XPUPlatform(Platform):
|
|||||||
"only NHD layout is supported by XPU attention kernels."
|
"only NHD layout is supported by XPU attention kernels."
|
||||||
)
|
)
|
||||||
|
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
|
|
||||||
if use_sparse:
|
if use_sparse:
|
||||||
raise NotImplementedError("Sparse Attention is not supported on XPU.")
|
raise NotImplementedError("Sparse Attention is not supported on XPU.")
|
||||||
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
|
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
|
||||||
@ -116,8 +113,6 @@ class XPUPlatform(Platform):
|
|||||||
def get_vit_attn_backend(
|
def get_vit_attn_backend(
|
||||||
cls, head_size: int, dtype: torch.dtype
|
cls, head_size: int, dtype: torch.dtype
|
||||||
) -> "AttentionBackendEnum":
|
) -> "AttentionBackendEnum":
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
|
|
||||||
return AttentionBackendEnum.FLASH_ATTN
|
return AttentionBackendEnum.FLASH_ATTN
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -51,8 +51,6 @@ class CPUAttentionBackend(AttentionBackend):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||||
"""CPU attention supports decoder and encoder-only attention."""
|
"""CPU attention supports decoder and encoder-only attention."""
|
||||||
from vllm.attention.backends.abstract import AttentionType
|
|
||||||
|
|
||||||
return attn_type in (
|
return attn_type in (
|
||||||
AttentionType.DECODER,
|
AttentionType.DECODER,
|
||||||
AttentionType.ENCODER,
|
AttentionType.ENCODER,
|
||||||
|
|||||||
@ -84,8 +84,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||||
"""FlashAttention supports all attention types."""
|
"""FlashAttention supports all attention types."""
|
||||||
from vllm.attention.backends.abstract import AttentionType
|
|
||||||
|
|
||||||
return attn_type in (
|
return attn_type in (
|
||||||
AttentionType.DECODER,
|
AttentionType.DECODER,
|
||||||
AttentionType.ENCODER,
|
AttentionType.ENCODER,
|
||||||
|
|||||||
@ -87,8 +87,6 @@ class FlexAttentionBackend(AttentionBackend):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||||
"""FlexAttention supports both decoder and encoder-only attention."""
|
"""FlexAttention supports both decoder and encoder-only attention."""
|
||||||
from vllm.attention.backends.abstract import AttentionType
|
|
||||||
|
|
||||||
return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
|
return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -24,12 +24,15 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
|
|||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionImpl
|
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
from vllm.attention.backends.abstract import (
|
||||||
|
AttentionBackend,
|
||||||
|
AttentionImpl,
|
||||||
|
AttentionMetadata,
|
||||||
|
)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||||
get_kv_connector_cache_layout,
|
get_kv_connector_cache_layout,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -6,12 +6,12 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
|
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
|
||||||
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
|
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -51,7 +51,7 @@ class OffloadingSpec(ABC):
|
|||||||
def get_handlers(
|
def get_handlers(
|
||||||
self,
|
self,
|
||||||
kv_caches: dict[str, torch.Tensor],
|
kv_caches: dict[str, torch.Tensor],
|
||||||
attn_backends: dict[str, type["AttentionBackend"]],
|
attn_backends: dict[str, type[AttentionBackend]],
|
||||||
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
|
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
|
||||||
"""
|
"""
|
||||||
Get offloading handlers along with their respective src and dst types.
|
Get offloading handlers along with their respective src and dst types.
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.config import (
|
from vllm.config import (
|
||||||
CompilationMode,
|
CompilationMode,
|
||||||
CUDAGraphMode,
|
CUDAGraphMode,
|
||||||
@ -157,8 +158,6 @@ class EagleProposer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Determine allowed attention backends once during initialization.
|
# Determine allowed attention backends once during initialization.
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
|
||||||
|
|
||||||
self.allowed_attn_types: tuple | None = None
|
self.allowed_attn_types: tuple | None = None
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
|
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
|
||||||
|
|||||||
@ -2,11 +2,11 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||||
from vllm.model_executor.models.utils import extract_layer_index
|
from vllm.model_executor.models.utils import extract_layer_index
|
||||||
@ -17,9 +17,6 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
|||||||
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
||||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.attention.layer import Attention
|
|
||||||
|
|
||||||
|
|
||||||
class MultiModalBudget:
|
class MultiModalBudget:
|
||||||
"""Helper class to calculate budget information for multi-modal models."""
|
"""Helper class to calculate budget information for multi-modal models."""
|
||||||
@ -278,7 +275,7 @@ def add_kv_sharing_layers_to_kv_cache_groups(
|
|||||||
|
|
||||||
def bind_kv_cache(
|
def bind_kv_cache(
|
||||||
kv_caches: dict[str, torch.Tensor],
|
kv_caches: dict[str, torch.Tensor],
|
||||||
forward_context: dict[str, "Attention"],
|
forward_context: dict[str, Attention],
|
||||||
runner_kv_caches: list[torch.Tensor],
|
runner_kv_caches: list[torch.Tensor],
|
||||||
num_attn_module: int = 1,
|
num_attn_module: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user