mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:55:01 +08:00
[UX] Fail if an invalid attention backend is specified (#22217)
Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
parent
cdfd6871a5
commit
e79a12fc3a
@ -278,23 +278,13 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
@pytest.mark.parametrize("use_v1", [True, False])
|
||||
def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
"""Test that invalid attention backend names raise ValueError."""
|
||||
with monkeypatch.context() as m, patch(
|
||||
"vllm.attention.selector.current_platform", CudaPlatform()):
|
||||
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
|
||||
m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
|
||||
|
||||
# Test with head size 32
|
||||
backend = get_attn_backend(32, torch.float16, None, 16, False)
|
||||
EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN"
|
||||
assert backend.get_name() == EXPECTED
|
||||
|
||||
# when block size == 16, backend will fall back to XFORMERS
|
||||
# this behavior is not yet supported on V1.
|
||||
if use_v1:
|
||||
# TODO: support fallback on V1!
|
||||
# https://github.com/vllm-project/vllm/issues/14524
|
||||
pass
|
||||
else:
|
||||
backend = get_attn_backend(16, torch.float16, None, 16, False)
|
||||
assert backend.get_name() == "XFORMERS"
|
||||
# Should raise ValueError for invalid backend
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_attn_backend(32, torch.float16, None, 16, False)
|
||||
assert "Invalid attention backend: 'INVALID'" in str(exc_info.value)
|
||||
|
||||
@ -193,6 +193,10 @@ def _cached_get_attn_backend(
|
||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||
if backend_by_env_var is not None:
|
||||
selected_backend = backend_name_to_enum(backend_by_env_var)
|
||||
if selected_backend is None:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend: '{backend_by_env_var}'. "
|
||||
f"Valid backends are: {list(_Backend.__members__.keys())}")
|
||||
|
||||
# get device-specific attn_backend
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user