[UX] Fail if an invalid attention backend is specified (#22217)
Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
parent cdfd6871a5
commit e79a12fc3a
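In user terms: setting VLLM_ATTENTION_BACKEND to a name that is not a known backend now raises a ValueError up front instead of silently falling back to a default backend. A minimal sketch of that behavior, assuming a CUDA machine and the get_attn_backend call signature used in the test below (not verified against a specific vLLM release):

# Sketch only: triggering the new validation through the environment variable.
import os
os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID"

import torch
from vllm.attention.selector import get_attn_backend  # module path per the test's patch target

try:
    get_attn_backend(32, torch.float16, None, 16, False)
except ValueError as err:
    print(err)  # expected to list the valid _Backend member names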
@@ -278,23 +278,13 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
 @pytest.mark.parametrize("use_v1", [True, False])
 def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch):
+    """Test that invalid attention backend names raise ValueError."""
     with monkeypatch.context() as m, patch(
             "vllm.attention.selector.current_platform", CudaPlatform()):
         m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
         m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
 
-        # Test with head size 32
-        backend = get_attn_backend(32, torch.float16, None, 16, False)
-        EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN"
-        assert backend.get_name() == EXPECTED
-
-        # when block size == 16, backend will fall back to XFORMERS
-        # this behavior is not yet supported on V1.
-        if use_v1:
-            # TODO: support fallback on V1!
-            # https://github.com/vllm-project/vllm/issues/14524
-            pass
-        else:
-            backend = get_attn_backend(16, torch.float16, None, 16, False)
-            assert backend.get_name() == "XFORMERS"
+        # Should raise ValueError for invalid backend
+        with pytest.raises(ValueError) as exc_info:
+            get_attn_backend(32, torch.float16, None, 16, False)
+        assert "Invalid attention backend: 'INVALID'" in str(exc_info.value)
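The rewritten test asserts on the exception message via pytest.raises(...) as exc_info. A standalone illustration of that pattern, using a placeholder helper rather than anything from the vLLM test suite:

import pytest

def select_backend(name: str) -> str:
    # Placeholder mimicking the validation exercised by test_invalid_env.
    if name not in ("FLASH_ATTN", "XFORMERS"):
        raise ValueError(f"Invalid attention backend: '{name}'")
    return name

def test_invalid_name_raises():
    with pytest.raises(ValueError) as exc_info:
        select_backend("INVALID")
    # exc_info.value holds the raised exception; assert on its message.
    assert "Invalid attention backend: 'INVALID'" in str(exc_info.value)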
@@ -193,6 +193,10 @@ def _cached_get_attn_backend(
     backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
     if backend_by_env_var is not None:
         selected_backend = backend_name_to_enum(backend_by_env_var)
+        if selected_backend is None:
+            raise ValueError(
+                f"Invalid attention backend: '{backend_by_env_var}'. "
+                f"Valid backends are: {list(_Backend.__members__.keys())}")
 
     # get device-specific attn_backend
     attention_cls = current_platform.get_attn_backend_cls(
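The selector change itself is a guard on backend_name_to_enum returning None for unknown names. A self-contained sketch of the same pattern; the _Backend enum below is a stand-in for vLLM's real enum, which defines many more members:

import enum

class _Backend(enum.Enum):
    # Stand-in; vLLM's real _Backend lists all supported attention backends.
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()

def backend_name_to_enum(backend_name: str):
    # Returns None for unknown names, mirroring the helper used in vllm.attention.selector.
    return _Backend.__members__.get(backend_name)

def resolve_backend(backend_name: str) -> _Backend:
    selected_backend = backend_name_to_enum(backend_name)
    if selected_backend is None:
        raise ValueError(
            f"Invalid attention backend: '{backend_name}'. "
            f"Valid backends are: {list(_Backend.__members__.keys())}")
    return selected_backend

print(resolve_backend("XFORMERS"))  # _Backend.XFORMERS
# resolve_backend("INVALID")        # raises ValueError listing the valid names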