diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index 93bf20da4adb..bfeafaa9e27e 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -278,23 +278,13 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
 
 @pytest.mark.parametrize("use_v1", [True, False])
 def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch):
-
+    """Test that invalid attention backend names raise ValueError."""
     with monkeypatch.context() as m, patch(
             "vllm.attention.selector.current_platform", CudaPlatform()):
         m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
         m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
 
-        # Test with head size 32
-        backend = get_attn_backend(32, torch.float16, None, 16, False)
-        EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN"
-        assert backend.get_name() == EXPECTED
-
-        # when block size == 16, backend will fall back to XFORMERS
-        # this behavior is not yet supported on V1.
-        if use_v1:
-            # TODO: support fallback on V1!
-            # https://github.com/vllm-project/vllm/issues/14524
-            pass
-        else:
-            backend = get_attn_backend(16, torch.float16, None, 16, False)
-            assert backend.get_name() == "XFORMERS"
+        # Should raise ValueError for invalid backend
+        with pytest.raises(ValueError) as exc_info:
+            get_attn_backend(32, torch.float16, None, 16, False)
+        assert "Invalid attention backend: 'INVALID'" in str(exc_info.value)
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 2e3c8638125f..596c556e54f0 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -193,6 +193,10 @@ def _cached_get_attn_backend(
     backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
     if backend_by_env_var is not None:
         selected_backend = backend_name_to_enum(backend_by_env_var)
+        if selected_backend is None:
+            raise ValueError(
+                f"Invalid attention backend: '{backend_by_env_var}'. "
+                f"Valid backends are: {list(_Backend.__members__.keys())}")
 
     # get device-specific attn_backend
     attention_cls = current_platform.get_attn_backend_cls(
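
For reference, below is a minimal, self-contained sketch of the validation pattern the selector.py hunk introduces: look up the env-provided backend name in the backend enum and raise a ValueError listing the valid members when the name is unknown. This is not vLLM's actual module; the `resolve_backend_from_env` helper is hypothetical, and the `_Backend` enum here lists only a small subset of member names for illustration (in vLLM the check lives inside `_cached_get_attn_backend`).

# Standalone sketch of the fail-fast backend-name validation pattern.
import enum
import os
from typing import Optional


class _Backend(enum.Enum):
    # Illustrative subset; the real enum defines many more backends.
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()


def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    """Return the matching enum member, or None if the name is unknown."""
    return _Backend.__members__.get(backend_name)


def resolve_backend_from_env() -> Optional[_Backend]:
    """Hypothetical helper mirroring the validation added in selector.py."""
    backend_by_env_var = os.environ.get("VLLM_ATTENTION_BACKEND")
    if backend_by_env_var is None:
        return None
    selected_backend = backend_name_to_enum(backend_by_env_var)
    if selected_backend is None:
        # Fail fast instead of silently falling back to a default backend.
        raise ValueError(
            f"Invalid attention backend: '{backend_by_env_var}'. "
            f"Valid backends are: {list(_Backend.__members__.keys())}")
    return selected_backend


if __name__ == "__main__":
    os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID"
    try:
        resolve_backend_from_env()
    except ValueError as e:
        # Prints: Invalid attention backend: 'INVALID'. Valid backends are: [...]
        print(e)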