[UX] Fail if an invalid attention backend is specified (#22217)

Signed-off-by: mgoin <michael@neuralmagic.com>
Author: Michael Goin
Date: 2025-08-05 02:54:52 -04:00
Committed by: GitHub
Parent: cdfd6871a5
Commit: e79a12fc3a
2 changed files with 9 additions and 15 deletions
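
For context, a minimal repro sketch of the behavior this commit enforces. It assumes a machine where vLLM imports cleanly and that get_attn_backend is importable from vllm.attention.selector (the module the test below patches); the call signature is copied from the test.

    import os
    import torch

    # The env var must be set before the selector first runs, since
    # _cached_get_attn_backend is, per its name, presumably memoized.
    os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID"

    from vllm.attention.selector import get_attn_backend

    # Before this commit the invalid name was silently ignored and a
    # default backend was chosen; now selection fails fast with:
    #   ValueError: Invalid attention backend: 'INVALID'. Valid backends are: [...]
    get_attn_backend(32, torch.float16, None, 16, False)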


@@ -278,23 +278,13 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
 @pytest.mark.parametrize("use_v1", [True, False])
 def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch):
-
+    """Test that invalid attention backend names raise ValueError."""
     with monkeypatch.context() as m, patch(
             "vllm.attention.selector.current_platform", CudaPlatform()):
         m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
         m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
 
-        # Test with head size 32
-        backend = get_attn_backend(32, torch.float16, None, 16, False)
-        EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN"
-        assert backend.get_name() == EXPECTED
-
-        # when block size == 16, backend will fall back to XFORMERS
-        # this behavior is not yet supported on V1.
-        if use_v1:
-            # TODO: support fallback on V1!
-            # https://github.com/vllm-project/vllm/issues/14524
-            pass
-        else:
-            backend = get_attn_backend(16, torch.float16, None, 16, False)
-            assert backend.get_name() == "XFORMERS"
+        # Should raise ValueError for invalid backend
+        with pytest.raises(ValueError) as exc_info:
+            get_attn_backend(32, torch.float16, None, 16, False)
+        assert "Invalid attention backend: 'INVALID'" in str(exc_info.value)


@@ -193,6 +193,10 @@ def _cached_get_attn_backend(
     backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
     if backend_by_env_var is not None:
         selected_backend = backend_name_to_enum(backend_by_env_var)
+        if selected_backend is None:
+            raise ValueError(
+                f"Invalid attention backend: '{backend_by_env_var}'. "
+                f"Valid backends are: {list(_Backend.__members__.keys())}")
 
     # get device-specific attn_backend
     attention_cls = current_platform.get_attn_backend_cls(
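
The new check leans on backend_name_to_enum returning None for names that are not _Backend members. A minimal sketch of that contract, with an assumed subset of the real enum:

    from enum import Enum, auto
    from typing import Optional

    class _Backend(Enum):
        # Assumed subset of vLLM's backend enum; the real one has more members.
        FLASH_ATTN = auto()
        FLASH_ATTN_VLLM_V1 = auto()
        XFORMERS = auto()

    def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
        # Unknown names map to None rather than raising, which is what the
        # `if selected_backend is None` branch above turns into a ValueError.
        return _Backend.__members__.get(backend_name)

    assert backend_name_to_enum("XFORMERS") is _Backend.XFORMERS
    assert backend_name_to_enum("INVALID") is None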