[Test] Add xformer and flash attn tests (#3961)

Co-authored-by: Simon Mo <simon.mo@hey.com>
Author: SangBin Cho, 2024-04-11 12:09:50 +09:00 (committed by GitHub)
parent caada5e50a
commit e42df7227d
2 changed files with 15 additions and 0 deletions

tests/basic_correctness/test_basic_correctness.py

@@ -4,6 +4,8 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`.
"""
import pytest
from vllm.attention.selector import VLLM_ATTENTION_BACKEND
MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
@@ -14,6 +16,7 @@ MODELS = [
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False, True])
@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
def test_models(
hf_runner,
vllm_runner,
@@ -22,7 +25,10 @@ def test_models(
dtype: str,
max_tokens: int,
enforce_eager: bool,
attn_backend: str,
monkeypatch,
) -> None:
monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend)
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
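
For context (not part of the diff): a minimal sketch of forcing the attention backend the same way the new test does, by setting the VLLM_ATTENTION_BACKEND environment variable before the engine is constructed. The model name and prompt are placeholders chosen for illustration.

import os

# Assumption: the selector (see vllm/attention/selector.py below) reads this
# variable when the engine is built, so it must be set beforehand.
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"  # or "FLASH_ATTN"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.0, max_tokens=5)  # greedy, 5 new tokens
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)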

vllm/attention/selector.py

@@ -1,4 +1,5 @@
import enum
import os
from functools import lru_cache
from typing import Type
@@ -10,6 +11,8 @@ from vllm.utils import is_cpu, is_hip
logger = init_logger(__name__)
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
@@ -75,4 +78,10 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
"Cannot use FlashAttention backend because the flash_attn package "
"is not found. Please install it for better performance.")
return _Backend.XFORMERS
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
if backend_by_env_var is not None:
return _Backend[backend_by_env_var]
# Default case.
return _Backend.FLASH_ATTN
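
A standalone sketch (not from the diff) mirroring the selection logic added above, built around the same enum-name lookup `_Backend[...]`; an unrecognized value in the environment variable raises KeyError rather than falling back to the default.

import enum
import os

class _Backend(enum.Enum):  # mirrors the enum in vllm/attention/selector.py
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()

def _pick_backend() -> _Backend:
    # The env var wins; otherwise fall through to the default, as in the diff.
    backend_by_env_var = os.getenv("VLLM_ATTENTION_BACKEND")
    if backend_by_env_var is not None:
        return _Backend[backend_by_env_var]  # KeyError for unknown names
    return _Backend.FLASH_ATTN

os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"
assert _pick_backend() is _Backend.XFORMERS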