diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index 97cff623c5e1d..bd4c7ea3301be 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -4,6 +4,8 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`.
 """
 import pytest
 
+from vllm.attention.selector import VLLM_ATTENTION_BACKEND
+
 MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
@@ -14,6 +16,7 @@ MODELS = [
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False, True])
+@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
 def test_models(
     hf_runner,
     vllm_runner,
@@ -22,7 +25,10 @@ def test_models(
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
+    attn_backend: str,
+    monkeypatch,
 ) -> None:
+    monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend)
     hf_model = hf_runner(model, dtype=dtype)
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 4c699aed48d49..554e802cd5513 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -1,4 +1,5 @@
 import enum
+import os
 from functools import lru_cache
 from typing import Type
 
@@ -10,6 +11,8 @@ from vllm.utils import is_cpu, is_hip
 
 logger = init_logger(__name__)
 
+VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
+
 
 class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()
@@ -75,4 +78,10 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
             "Cannot use FlashAttention backend because the flash_attn package "
             "is not found. Please install it for better performance.")
         return _Backend.XFORMERS
+
+    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
+    if backend_by_env_var is not None:
+        return _Backend[backend_by_env_var]
+
+    # Default case.
     return _Backend.FLASH_ATTN
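Usage sketch (not part of the patch): with this change, setting the VLLM_ATTENTION_BACKEND environment variable before vLLM selects a backend forces the choice by _Backend member name, e.g. "XFORMERS" or "FLASH_ATTN", as the test above does via monkeypatch.setenv. Note that, as placed in _which_attn_to_use, the override is only consulted on the default path, after the CPU/ROCm, compute-capability, dtype, and flash_attn-availability checks have not already returned. The model, prompt, and sampling settings below are illustrative assumptions, not taken from the diff.

    import os

    # Force the xFormers backend; the value must match a _Backend enum member.
    # Set it before the engine is constructed, since the selection is cached.
    os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"

    from vllm import LLM, SamplingParams  # noqa: E402

    # Illustrative model and prompt; any supported model works the same way.
    llm = LLM(model="facebook/opt-125m", dtype="half", enforce_eager=True)
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=5))
    print(outputs[0].outputs[0].text)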