mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-04-01 02:47:04 +08:00)
[Platform][Refactor] Extract func get_default_attn_backend to Platform (#10358)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
parent 7eb719df13
commit 8c1fb50705
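In short: the commit replaces the per-device if current_platform.is_*() chain inside which_attn_to_use with a single polymorphic hook, Platform.get_default_attn_backend, that each platform class overrides. Below is a minimal sketch of the pattern, with the enum trimmed to three members for brevity; the full definitions appear in the hunks that follow.

import enum
from typing import Optional


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    TORCH_SDPA = enum.auto()


class Platform:
    @classmethod
    def get_default_attn_backend(
            cls, selected_backend: _Backend) -> Optional[_Backend]:
        # Base hook: no device-specific default, so the caller falls
        # through to the generic (CUDA) selection logic.
        return None


class CpuPlatform(Platform):
    @classmethod
    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
        # CPU always ends up on torch SDPA, whatever was requested.
        return _Backend.TORCH_SDPA

Each accelerator now owns its default backend choice, and the selector shrinks to one dispatch call (see the condensed call-site sketch after the which_attn_to_use hunks below).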
@@ -5,6 +5,7 @@ import torch

 from tests.kernels.utils import override_backend_env_variable
 from vllm.attention.selector import which_attn_to_use
+from vllm.platforms import cpu, cuda, openvino, rocm
 from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
@@ -19,26 +20,28 @@ def test_env(name: str, device: str, monkeypatch):
     override_backend_env_variable(monkeypatch, name)

     if device == "cpu":
-        with patch("vllm.attention.selector.current_platform.is_cpu",
-                   return_value=True):
+        with patch("vllm.attention.selector.current_platform",
+                   cpu.CpuPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == "TORCH_SDPA"
     elif device == "hip":
-        with patch("vllm.attention.selector.current_platform.is_rocm",
-                   return_value=True):
+        with patch("vllm.attention.selector.current_platform",
+                   rocm.RocmPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == "ROCM_FLASH"
     elif device == "openvino":
-        with patch("vllm.attention.selector.current_platform.is_openvino",
-                   return_value=True):
+        with patch("vllm.attention.selector.current_platform",
+                   openvino.OpenVinoPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == "OPENVINO"
     else:
-        backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                    False)
+        with patch("vllm.attention.selector.current_platform",
+                   cuda.CudaPlatform()):
+            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
+                                        False)
         assert backend.name == name
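The rewritten test patches the whole current_platform object with a concrete platform instance, so the selector exercises the real get_default_attn_backend override instead of a stubbed is_cpu() flag. A self-contained sketch of the same patching technique, reusing the toy classes from the sketch above (the SimpleNamespace is a stand-in for the vllm.attention.selector module, not vLLM's real layout):

import types
from unittest.mock import patch

# Toy stand-in for the selector module and its module-level platform.
selector = types.SimpleNamespace(current_platform=Platform())


def pick_backend() -> _Backend:
    default = selector.current_platform.get_default_attn_backend(
        _Backend.FLASH_ATTN)
    return default if default is not None else _Backend.FLASH_ATTN


# patch.object swaps the attribute for the duration of the block, just as
# patch("vllm.attention.selector.current_platform", cpu.CpuPlatform())
# does in the test above.
with patch.object(selector, "current_platform", CpuPlatform()):
    assert pick_backend() is _Backend.TORCH_SDPA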
@@ -1,4 +1,3 @@
-import enum
 import os
 from contextlib import contextmanager
 from functools import lru_cache
@@ -9,26 +8,12 @@ import torch

 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
+from vllm.platforms import _Backend, current_platform
 from vllm.utils import STR_BACKEND_ENV_VAR

 logger = init_logger(__name__)


-class _Backend(enum.Enum):
-    FLASH_ATTN = enum.auto()
-    FLASH_ATTN_VLLM_V1 = enum.auto()
-    XFORMERS = enum.auto()
-    ROCM_FLASH = enum.auto()
-    TORCH_SDPA = enum.auto()
-    OPENVINO = enum.auto()
-    FLASHINFER = enum.auto()
-    HPU_ATTN = enum.auto()
-    PALLAS = enum.auto()
-    IPEX = enum.auto()
-    NO_ATTENTION = enum.auto()
-
-
 def backend_name_to_enum(backend_name: str) -> _Backend:
     assert backend_name is not None
@@ -216,40 +201,11 @@ def which_attn_to_use(head_size: int,
     if backend_by_env_var is not None:
         selected_backend = backend_name_to_enum(backend_by_env_var)

-    if current_platform.is_cpu():
-        if selected_backend != _Backend.TORCH_SDPA:
-            logger.info("Cannot use %s backend on CPU.", selected_backend)
-        return _Backend.TORCH_SDPA
-
-    if current_platform.is_openvino():
-        if selected_backend != _Backend.OPENVINO:
-            logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
-        return _Backend.OPENVINO
-
-    if current_platform.is_xpu():
-        if selected_backend != _Backend.IPEX:
-            logger.info("Cannot use %s backend on XPU.", selected_backend)
-        return _Backend.IPEX
-
-    if current_platform.is_tpu():
-        if selected_backend != _Backend.PALLAS:
-            logger.info("Cannot use %s backend on TPU.", selected_backend)
-        return _Backend.PALLAS
-
-    if current_platform.is_rocm():
-        # AMD GPUs.
-        selected_backend = (_Backend.ROCM_FLASH if selected_backend
-                            == _Backend.FLASH_ATTN else selected_backend)
-        if selected_backend == _Backend.ROCM_FLASH:
-            if not current_platform.has_device_capability(90):
-                # not Instinct series GPUs.
-                logger.info("flash_attn is not supported on NAVI GPUs.")
-        else:
-            logger.info("%s is not supported in AMD GPUs.", selected_backend)
-        return _Backend.ROCM_FLASH
-
-    if current_platform.is_hpu():
-        return _Backend.HPU_ATTN
+    # get device-specific default attn_backend
+    default_backend = current_platform.get_default_attn_backend(
+        selected_backend)
+    if default_backend is not None:
+        return default_backend

     if use_v1:
         return _Backend.FLASH_ATTN_VLLM_V1
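With the five platform branches gone, which_attn_to_use reads as a plain fallthrough: ask the platform first, and only when it declines (returns None) continue into the generic CUDA logic. A condensed, runnable sketch of that call site, again reusing the toy classes from the first sketch (the argument list is shortened; the real function also takes the head size, dtype, and block-size arguments seen in the test above):

def select_backend(platform: Platform, selected_backend: _Backend,
                   use_v1: bool) -> _Backend:
    # 1. Platforms with a fixed requirement (CPU, OpenVINO, XPU, TPU,
    #    ROCm, HPU) short-circuit here through their overrides.
    default_backend = platform.get_default_attn_backend(selected_backend)
    if default_backend is not None:
        return default_backend
    # 2. Only the generic path remains; in the real function this is
    #    where the use_v1 and CUDA capability checks live.
    if use_v1:
        return _Backend.FLASH_ATTN_VLLM_V1
    return selected_backend


assert select_backend(CpuPlatform(), _Backend.FLASH_ATTN,
                      use_v1=False) is _Backend.TORCH_SDPA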
@@ -13,7 +13,6 @@ from torch.nn import functional as F
 from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
-from vllm.attention.selector import _Backend
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -38,6 +37,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.utils import cached_get_tokenizer
+from vllm.platforms import _Backend
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
 from vllm.transformers_utils.processor import get_processor
@@ -39,7 +39,6 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
     make_batched_images, make_batched_videos, smart_resize)

 from vllm.attention import AttentionMetadata
-from vllm.attention.selector import _Backend
 from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group, parallel_state
 from vllm.distributed import utils as dist_utils
@@ -65,6 +64,7 @@ from vllm.multimodal.image import cached_get_image_processor
 from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
                                     MultiModalKwargs)
 from vllm.multimodal.utils import cached_get_tokenizer
+from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
 from vllm.transformers_utils.config import uses_mrope
 from vllm.transformers_utils.processor import cached_get_processor
@@ -9,13 +9,13 @@ from torch.func import functional_call
 from transformers import PretrainedConfig

 import vllm.envs as envs
-from vllm.attention.selector import (_Backend, backend_name_to_enum,
+from vllm.attention.selector import (backend_name_to_enum,
                                      get_global_forced_attn_backend)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
-from vllm.platforms import current_platform
+from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available
@@ -1,3 +1,4 @@
+from .interface import _Backend  # noqa: F401
 from .interface import Platform, PlatformEnum, UnspecifiedPlatform

 current_platform: Platform
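The _Backend enum itself moves out of vllm.attention.selector into vllm.platforms.interface and is re-exported here (the noqa: F401 marks the re-export as intentional). Since the selector imports from vllm.platforms, leaving the enum in the selector would have pulled the new platform hooks into an import cycle. Downstream code now has one canonical import; the import swaps throughout this commit apply exactly this change:

# Old location, removed throughout the tree by this commit:
#   from vllm.attention.selector import _Backend
# New canonical location:
from vllm.platforms import _Backend

assert _Backend.TORCH_SDPA.name == "TORCH_SDPA"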
@@ -5,7 +5,9 @@ import torch

+from vllm.logger import init_logger
+
-from .interface import Platform, PlatformEnum
+from .interface import Platform, PlatformEnum, _Backend

+logger = init_logger(__name__)

 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -22,6 +24,12 @@ class CpuPlatform(Platform):
     def get_device_name(cls, device_id: int = 0) -> str:
         return "cpu"

+    @classmethod
+    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+        if selected_backend != _Backend.TORCH_SDPA:
+            logger.info("Cannot use %s backend on CPU.", selected_backend)
+        return _Backend.TORCH_SDPA
+
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         return psutil.virtual_memory().total
@@ -1,11 +1,15 @@
 import torch

-from .interface import Platform, PlatformEnum
+from .interface import Platform, PlatformEnum, _Backend


 class HpuPlatform(Platform):
     _enum = PlatformEnum.HPU

+    @classmethod
+    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+        return _Backend.HPU_ATTN
+
     @staticmethod
     def inference_mode():
         return torch.no_grad()
@@ -11,6 +11,20 @@ else:
     VllmConfig = None


+class _Backend(enum.Enum):
+    FLASH_ATTN = enum.auto()
+    FLASH_ATTN_VLLM_V1 = enum.auto()
+    XFORMERS = enum.auto()
+    ROCM_FLASH = enum.auto()
+    TORCH_SDPA = enum.auto()
+    OPENVINO = enum.auto()
+    FLASHINFER = enum.auto()
+    HPU_ATTN = enum.auto()
+    PALLAS = enum.auto()
+    IPEX = enum.auto()
+    NO_ATTENTION = enum.auto()
+
+
 class PlatformEnum(enum.Enum):
     CUDA = enum.auto()
     ROCM = enum.auto()
@@ -71,6 +85,11 @@ class Platform:
         """Stateless version of :func:`torch.cuda.is_available`."""
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

+    @classmethod
+    def get_default_attn_backend(cls, selected_backend: _Backend):
+        """Get the default attention backend of a device."""
+        return None
+
     @classmethod
     def get_device_capability(
         cls,
@@ -3,7 +3,7 @@ import torch
 import vllm.envs as envs
 from vllm.logger import init_logger

-from .interface import Platform, PlatformEnum
+from .interface import Platform, PlatformEnum, _Backend

 logger = init_logger(__name__)
@@ -11,6 +11,12 @@ logger = init_logger(__name__)
 class OpenVinoPlatform(Platform):
     _enum = PlatformEnum.OPENVINO

+    @classmethod
+    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+        if selected_backend != _Backend.OPENVINO:
+            logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
+        return _Backend.OPENVINO
+
     @classmethod
     def get_device_name(self, device_id: int = 0) -> str:
         return "openvino"
@@ -5,7 +5,7 @@ import torch

 from vllm.logger import init_logger

-from .interface import DeviceCapability, Platform, PlatformEnum
+from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

 logger = init_logger(__name__)
@@ -19,6 +19,18 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
 class RocmPlatform(Platform):
     _enum = PlatformEnum.ROCM

+    @classmethod
+    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+        selected_backend = (_Backend.ROCM_FLASH if selected_backend
+                            == _Backend.FLASH_ATTN else selected_backend)
+        if selected_backend == _Backend.ROCM_FLASH:
+            if not cls.has_device_capability(90):
+                # not Instinct series GPUs.
+                logger.info("flash_attn is not supported on NAVI GPUs.")
+        else:
+            logger.info("%s is not supported in AMD GPUs.", selected_backend)
+        return _Backend.ROCM_FLASH
+
     @classmethod
     @lru_cache(maxsize=8)
     def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
@@ -3,17 +3,27 @@ from typing import TYPE_CHECKING

 import torch

-from .interface import Platform, PlatformEnum
+from vllm.logger import init_logger
+
+from .interface import Platform, PlatformEnum, _Backend

 if TYPE_CHECKING:
     from vllm.config import VllmConfig
 else:
     VllmConfig = None

+logger = init_logger(__name__)
+

 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU

+    @classmethod
+    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+        if selected_backend != _Backend.PALLAS:
+            logger.info("Cannot use %s backend on TPU.", selected_backend)
+        return _Backend.PALLAS
+
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         raise NotImplementedError
@@ -1,11 +1,21 @@
 import torch

-from .interface import DeviceCapability, Platform, PlatformEnum
+from vllm.logger import init_logger
+
+from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
+
+logger = init_logger(__name__)


 class XPUPlatform(Platform):
     _enum = PlatformEnum.XPU

+    @classmethod
+    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+        if selected_backend != _Backend.IPEX:
+            logger.info("Cannot use %s backend on XPU.", selected_backend)
+        return _Backend.IPEX
+
     @staticmethod
     def get_device_capability(device_id: int = 0) -> DeviceCapability:
         major, minor, *_ = torch.xpu.get_device_capability(
@@ -8,7 +8,7 @@ import torch.distributed
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata)
 from vllm.attention.backends.utils import PAD_SLOT_ID
-from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
+from vllm.attention.selector import (get_env_variable_attn_backend,
                                      get_global_forced_attn_backend)
 from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
@@ -18,6 +18,7 @@ from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
                              MultiModalRegistry)
+from vllm.platforms import _Backend
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, PoolerOutput,
                            SequenceGroupMetadata)