[Platform][Refactor] Extract func get_default_attn_backend to Platform (#10358)

Signed-off-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao 2024-11-19 11:22:26 +08:00 committed by GitHub
parent 7eb719df13
commit 8c1fb50705
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 99 additions and 69 deletions

View File

@ -5,6 +5,7 @@ import torch
from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import which_attn_to_use
from vllm.platforms import cpu, cuda, openvino, rocm
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
@ -19,26 +20,28 @@ def test_env(name: str, device: str, monkeypatch):
override_backend_env_variable(monkeypatch, name)
if device == "cpu":
with patch("vllm.attention.selector.current_platform.is_cpu",
return_value=True):
with patch("vllm.attention.selector.current_platform",
cpu.CpuPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.current_platform.is_rocm",
return_value=True):
with patch("vllm.attention.selector.current_platform",
rocm.RocmPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "ROCM_FLASH"
elif device == "openvino":
with patch("vllm.attention.selector.current_platform.is_openvino",
return_value=True):
with patch("vllm.attention.selector.current_platform",
openvino.OpenVinoPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "OPENVINO"
else:
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
with patch("vllm.attention.selector.current_platform",
cuda.CudaPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == name

View File

@ -1,4 +1,3 @@
import enum
import os
from contextlib import contextmanager
from functools import lru_cache
@ -9,26 +8,12 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms import _Backend, current_platform
from vllm.utils import STR_BACKEND_ENV_VAR
logger = init_logger(__name__)
class _Backend(enum.Enum):
    """Identifiers for the attention backends that can be selected.

    Values are assigned with ``enum.auto()`` and therefore follow the
    declaration order below; compare members (or use ``.name``) rather
    than relying on the integer values.
    """
    FLASH_ATTN = enum.auto()
    # Returned by which_attn_to_use when use_v1 is set.
    FLASH_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    # Backend returned on ROCm (AMD GPUs).
    ROCM_FLASH = enum.auto()
    # Backend returned on CPU.
    TORCH_SDPA = enum.auto()
    # Backend returned on OpenVINO.
    OPENVINO = enum.auto()
    FLASHINFER = enum.auto()
    # Backend returned on HPU.
    HPU_ATTN = enum.auto()
    # Backend returned on TPU.
    PALLAS = enum.auto()
    # Backend returned on XPU.
    IPEX = enum.auto()
    NO_ATTENTION = enum.auto()
def backend_name_to_enum(backend_name: str) -> _Backend:
assert backend_name is not None
@ -216,40 +201,11 @@ def which_attn_to_use(head_size: int,
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if current_platform.is_cpu():
if selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA
if current_platform.is_openvino():
if selected_backend != _Backend.OPENVINO:
logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
return _Backend.OPENVINO
if current_platform.is_xpu():
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
return _Backend.IPEX
if current_platform.is_tpu():
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS
if current_platform.is_rocm():
# AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:
if not current_platform.has_device_capability(90):
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
logger.info("%s is not supported in AMD GPUs.", selected_backend)
return _Backend.ROCM_FLASH
if current_platform.is_hpu():
return _Backend.HPU_ATTN
# get device-specific default attn_backend
default_backend = current_platform.get_default_attn_backend(
selected_backend)
if default_backend is not None:
return default_backend
if use_v1:
return _Backend.FLASH_ATTN_VLLM_V1

View File

@ -13,7 +13,6 @@ from torch.nn import functional as F
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.selector import _Backend
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@ -38,6 +37,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.platforms import _Backend
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.transformers_utils.processor import get_processor

View File

@ -39,7 +39,6 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
make_batched_images, make_batched_videos, smart_resize)
from vllm.attention import AttentionMetadata
from vllm.attention.selector import _Backend
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, parallel_state
from vllm.distributed import utils as dist_utils
@ -65,6 +64,7 @@ from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
MultiModalKwargs)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import cached_get_processor

View File

@ -9,13 +9,13 @@ from torch.func import functional_call
from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.attention.selector import (_Backend, backend_name_to_enum,
from vllm.attention.selector import (backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
from vllm.platforms import current_platform
from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available

View File

@ -1,3 +1,4 @@
from .interface import _Backend # noqa: F401
from .interface import Platform, PlatformEnum, UnspecifiedPlatform
current_platform: Platform

View File

@ -5,7 +5,9 @@ import torch
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum
from .interface import Platform, PlatformEnum, _Backend
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.config import VllmConfig
@ -22,6 +24,12 @@ class CpuPlatform(Platform):
def get_device_name(cls, device_id: int = 0) -> str:
return "cpu"
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend used on CPU, which is always TORCH_SDPA.

    Args:
        selected_backend: Backend requested by the caller. Any value other
            than ``_Backend.TORCH_SDPA`` is overridden, and an info message
            is logged to say so.

    Returns:
        ``_Backend.TORCH_SDPA``.
    """
    cpu_backend = _Backend.TORCH_SDPA
    if selected_backend != cpu_backend:
        logger.info("Cannot use %s backend on CPU.", selected_backend)
    return cpu_backend
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Return the ``total`` field reported by ``psutil.virtual_memory()``.

    Args:
        device_id: Accepted for interface compatibility; not used here.
    """
    memory_stats = psutil.virtual_memory()
    return memory_stats.total

View File

@ -1,11 +1,15 @@
import torch
from .interface import Platform, PlatformEnum
from .interface import Platform, PlatformEnum, _Backend
class HpuPlatform(Platform):
_enum = PlatformEnum.HPU
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend used on HPU, which is always HPU_ATTN.

    Args:
        selected_backend: Backend requested by the caller; ignored, and no
            message is logged when it differs.

    Returns:
        ``_Backend.HPU_ATTN``.
    """
    hpu_backend = _Backend.HPU_ATTN
    return hpu_backend
@staticmethod
def inference_mode():
    """Return a fresh ``torch.no_grad()`` context manager."""
    no_grad_ctx = torch.no_grad()
    return no_grad_ctx

View File

@ -11,6 +11,20 @@ else:
VllmConfig = None
class _Backend(enum.Enum):
    """Identifiers for the attention backends that can be selected.

    Values are assigned with ``enum.auto()`` and therefore follow the
    declaration order below; compare members (or use ``.name``) rather
    than relying on the integer values.
    """
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    # Returned by RocmPlatform.get_default_attn_backend.
    ROCM_FLASH = enum.auto()
    # Returned by CpuPlatform.get_default_attn_backend.
    TORCH_SDPA = enum.auto()
    # Returned by OpenVinoPlatform.get_default_attn_backend.
    OPENVINO = enum.auto()
    FLASHINFER = enum.auto()
    # Returned by HpuPlatform.get_default_attn_backend.
    HPU_ATTN = enum.auto()
    # Returned by TpuPlatform.get_default_attn_backend.
    PALLAS = enum.auto()
    # Returned by XPUPlatform.get_default_attn_backend.
    IPEX = enum.auto()
    NO_ATTENTION = enum.auto()
class PlatformEnum(enum.Enum):
CUDA = enum.auto()
ROCM = enum.auto()
@ -71,6 +85,11 @@ class Platform:
"""Stateless version of :func:`torch.cuda.is_available`."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend):
    """Return the attention backend this platform pins, if any.

    The base implementation pins nothing and returns ``None``. Platform
    subclasses override this to return a specific ``_Backend`` member.

    Args:
        selected_backend: Backend requested by the caller; unused here.
    """
    return None
@classmethod
def get_device_capability(
cls,

View File

@ -3,7 +3,7 @@ import torch
import vllm.envs as envs
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum
from .interface import Platform, PlatformEnum, _Backend
logger = init_logger(__name__)
@ -11,6 +11,12 @@ logger = init_logger(__name__)
class OpenVinoPlatform(Platform):
_enum = PlatformEnum.OPENVINO
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend used on OpenVINO, always OPENVINO.

    Args:
        selected_backend: Backend requested by the caller. Any value other
            than ``_Backend.OPENVINO`` is overridden, and an info message
            is logged to say so.

    Returns:
        ``_Backend.OPENVINO``.
    """
    openvino_backend = _Backend.OPENVINO
    if selected_backend != openvino_backend:
        logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
    return openvino_backend
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
    """Return the device name for this platform, always ``"openvino"``.

    Args:
        device_id: Accepted for interface compatibility; not used here.
    """
    # The first parameter of a classmethod receives the class, so it is
    # named ``cls`` (it was previously misnamed ``self``).
    return "openvino"

View File

@ -5,7 +5,7 @@ import torch
from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
logger = init_logger(__name__)
@ -19,6 +19,18 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend used on ROCm, always ROCM_FLASH.

    A request for ``_Backend.FLASH_ATTN`` is treated as a request for
    ``_Backend.ROCM_FLASH``. The request only influences which info
    message is logged; the return value does not depend on it.

    Args:
        selected_backend: Backend requested by the caller.

    Returns:
        ``_Backend.ROCM_FLASH``.
    """
    rocm_backend = _Backend.ROCM_FLASH
    if selected_backend == _Backend.FLASH_ATTN:
        selected_backend = rocm_backend

    if selected_backend != rocm_backend:
        logger.info("%s is not supported in AMD GPUs.", selected_backend)
    elif not cls.has_device_capability(90):
        # not Instinct series GPUs.
        logger.info("flash_attn is not supported on NAVI GPUs.")
    return rocm_backend
@classmethod
@lru_cache(maxsize=8)
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:

View File

@ -3,17 +3,27 @@ from typing import TYPE_CHECKING
import torch
from .interface import Platform, PlatformEnum
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None
logger = init_logger(__name__)
class TpuPlatform(Platform):
_enum = PlatformEnum.TPU
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend used on TPU, which is always PALLAS.

    Args:
        selected_backend: Backend requested by the caller. Any value other
            than ``_Backend.PALLAS`` is overridden, and an info message is
            logged to say so.

    Returns:
        ``_Backend.PALLAS``.
    """
    tpu_backend = _Backend.PALLAS
    if selected_backend != tpu_backend:
        logger.info("Cannot use %s backend on TPU.", selected_backend)
    return tpu_backend
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
    """Not implemented for this platform.

    Args:
        device_id: Accepted for interface compatibility; not used here.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError

View File

@ -1,11 +1,21 @@
import torch
from .interface import DeviceCapability, Platform, PlatformEnum
from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
logger = init_logger(__name__)
class XPUPlatform(Platform):
_enum = PlatformEnum.XPU
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend used on XPU, which is always IPEX.

    Args:
        selected_backend: Backend requested by the caller. Any value other
            than ``_Backend.IPEX`` is overridden, and an info message is
            logged to say so.

    Returns:
        ``_Backend.IPEX``.
    """
    xpu_backend = _Backend.IPEX
    if selected_backend != xpu_backend:
        logger.info("Cannot use %s backend on XPU.", selected_backend)
    return xpu_backend
@staticmethod
def get_device_capability(device_id: int = 0) -> DeviceCapability:
major, minor, *_ = torch.xpu.get_device_capability(

View File

@ -8,7 +8,7 @@ import torch.distributed
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
from vllm.attention.selector import (get_env_variable_attn_backend,
get_global_forced_attn_backend)
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
@ -18,6 +18,7 @@ from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.platforms import _Backend
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput,
SequenceGroupMetadata)