mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 01:07:09 +08:00
[Platform] Do not raise error if _Backend is not found (#12023)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: Mengqing Cao <cmq0113@163.com> Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
parent
ad388d25a8
commit
3adf0ffda8
@ -94,7 +94,12 @@ def test_flash_attn(monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
def test_invalid_env(monkeypatch):
|
def test_invalid_env(monkeypatch):
|
||||||
"""Throw an exception if the backend name is invalid."""
|
"""Ignore the invalid env variable if it is set."""
|
||||||
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
|
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
|
||||||
with pytest.raises(ValueError):
|
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
|
||||||
get_attn_backend(16, torch.float16, None, 16, False)
|
backend = get_attn_backend(32, torch.float16, None, 16, False)
|
||||||
|
assert backend.get_name() == "FLASH_ATTN"
|
||||||
|
|
||||||
|
# when block size == 16, backend will fall back to XFORMERS
|
||||||
|
backend = get_attn_backend(16, torch.float16, None, 16, False)
|
||||||
|
assert backend.get_name() == "XFORMERS"
|
||||||
|
|||||||
@ -0,0 +1,8 @@
|
|||||||
|
from vllm.attention.backends.flash_attn import FlashAttentionBackend
|
||||||
|
|
||||||
|
|
||||||
|
class DummyAttentionBackend(FlashAttentionBackend):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_name() -> str:
|
||||||
|
return "Dummy_Backend"
|
||||||
@ -3,3 +3,7 @@ from vllm.platforms.cuda import CudaPlatform
|
|||||||
|
|
||||||
class DummyPlatform(CudaPlatform):
|
class DummyPlatform(CudaPlatform):
|
||||||
device_name = "DummyDevice"
|
device_name = "DummyDevice"
|
||||||
|
|
||||||
|
def get_attn_backend_cls(self, backend_name, head_size, dtype,
|
||||||
|
kv_cache_dtype, block_size, use_v1):
|
||||||
|
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501
|
||||||
|
|||||||
@ -1,3 +1,10 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from tests.kernels.utils import override_backend_env_variable
|
||||||
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
from vllm.utils import STR_INVALID_VAL
|
||||||
|
|
||||||
|
|
||||||
def test_platform_plugins():
|
def test_platform_plugins():
|
||||||
# simulate workload by running an example
|
# simulate workload by running an example
|
||||||
import runpy
|
import runpy
|
||||||
@ -14,3 +21,10 @@ def test_platform_plugins():
|
|||||||
f"Expected DummyDevice, got {current_platform.device_name}, "
|
f"Expected DummyDevice, got {current_platform.device_name}, "
|
||||||
"possibly because current_platform is imported before the plugin"
|
"possibly because current_platform is imported before the plugin"
|
||||||
f" is loaded. The first import:\n{_init_trace}")
|
f" is loaded. The first import:\n{_init_trace}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_oot_attention_backend(monkeypatch):
|
||||||
|
# ignore the backend env variable if it is set
|
||||||
|
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
|
||||||
|
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
|
||||||
|
assert backend.get_name() == "Dummy_Backend"
|
||||||
|
|||||||
@ -190,11 +190,11 @@ class MultiHeadAttention(nn.Module):
|
|||||||
kv_cache_dtype=None,
|
kv_cache_dtype=None,
|
||||||
block_size=16,
|
block_size=16,
|
||||||
is_attention_free=False)
|
is_attention_free=False)
|
||||||
attn_backend = backend_name_to_enum(attn_backend.get_name())
|
backend = backend_name_to_enum(attn_backend.get_name())
|
||||||
if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
|
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
|
||||||
attn_backend = _Backend.XFORMERS
|
backend = _Backend.XFORMERS
|
||||||
|
|
||||||
self.attn_backend = attn_backend if attn_backend in {
|
self.attn_backend = backend if backend in {
|
||||||
_Backend.TORCH_SDPA, _Backend.XFORMERS
|
_Backend.TORCH_SDPA, _Backend.XFORMERS
|
||||||
} else _Backend.TORCH_SDPA
|
} else _Backend.TORCH_SDPA
|
||||||
|
|
||||||
|
|||||||
@ -14,16 +14,18 @@ from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def backend_name_to_enum(backend_name: str) -> _Backend:
|
def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
|
||||||
|
"""
|
||||||
|
Convert a string backend name to a _Backend enum value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
* _Backend: enum value if backend_name is a valid in-tree type
|
||||||
|
* None: otherwise it's an invalid in-tree type or an out-of-tree platform is
|
||||||
|
loaded.
|
||||||
|
"""
|
||||||
assert backend_name is not None
|
assert backend_name is not None
|
||||||
|
return _Backend[backend_name] if backend_name in _Backend.__members__ else \
|
||||||
backend_members = _Backend.__members__
|
None
|
||||||
if backend_name not in backend_members:
|
|
||||||
raise ValueError(f"Invalid attention backend '{backend_name}'. "
|
|
||||||
f"Available backends: {', '.join(backend_members)} "
|
|
||||||
"(case-sensitive).")
|
|
||||||
|
|
||||||
return _Backend[backend_name]
|
|
||||||
|
|
||||||
|
|
||||||
def get_env_variable_attn_backend() -> Optional[_Backend]:
|
def get_env_variable_attn_backend() -> Optional[_Backend]:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user