[Bugfix]: During testing, use pytest monkeypatch for safely overriding the env var that indicates the vLLM backend (#5210)

afeldman-nm authored 2024-06-03 23:32:57 -04:00, committed by GitHub
parent 3a434b07ed
commit f42a006b15
2 changed files with 32 additions and 17 deletions
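
The tests previously saved VLLM_ATTENTION_BACKEND into a local name_backup and restored it by hand at the end of each test, so a failing assertion skipped the restore and leaked the override into later tests. pytest's monkeypatch fixture undoes setenv() during teardown no matter how the test exits. A minimal sketch of that pattern (illustrative only, not part of this commit; the test name is made up):

import os

import pytest


def test_backend_env_is_restored(monkeypatch: pytest.MonkeyPatch):
    # setenv records the variable's previous value (or absence) and
    # re-applies it at teardown, even if the assertion below fails.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
    assert os.environ["VLLM_ATTENTION_BACKEND"] == "XFORMERS"
    # No manual name_backup / os.environ restore logic is needed.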


@@ -1,21 +1,22 @@
-import os
 from unittest.mock import patch
 
 import pytest
 import torch
 
+from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL,
+                                 override_backend_env_variable)
 from vllm.attention.selector import which_attn_to_use
 
 
 @pytest.mark.parametrize(
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
 @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
-def test_env(name: str, device: str):
+def test_env(name: str, device: str, monkeypatch):
     """Test that the attention selector can be set via environment variable.
 
     Note that we do not test FlashAttn because it is the default backend.
     """
-    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
-    os.environ["VLLM_ATTENTION_BACKEND"] = name
+    override_backend_env_variable(monkeypatch, name)
 
     if device == "cpu":
         with patch("vllm.attention.selector.is_cpu", return_value=True):
@@ -32,14 +33,11 @@ def test_env(name: str, device: str):
                                         torch.float16, 16)
             assert backend.name == name
 
-    if name_backup is not None:
-        os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
 
 
-def test_flash_attn():
+def test_flash_attn(monkeypatch):
     """Test FlashAttn validation."""
-    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
-    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"
+    override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
 
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
@@ -71,14 +69,9 @@ def test_flash_attn():
     backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
     assert backend.name != "FLASH_ATTN"
 
-    if name_backup is not None:
-        os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
-
 
-def test_invalid_env():
+def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
-    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
-    os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID"
+    override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
         which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
-    os.environ["VLLM_ATTENTION_BACKEND"] = name_backup

tests/kernels/utils.py (new file, 22 lines added)

@@ -0,0 +1,22 @@
+"""Kernel test utils"""
+
+import pytest
+
+STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
+STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
+STR_INVALID_VAL: str = "INVALID"
+
+
+def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
+                                  backend_name: str) -> None:
+    '''
+    Override the environment variable indicating the vLLM backend temporarily,
+    using pytest monkeypatch to ensure that the env vars get
+    reset once the test context exits.
+
+    Arguments:
+
+    * mpatch: pytest monkeypatch instance
+    * backend_name: attention backend name to force
+    '''
+    mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name)
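
A hypothetical usage sketch of the new helper in a downstream kernel test (the test name and backend choice are illustrative, not part of this commit):

import pytest

from tests.kernels.utils import (STR_FLASH_ATTN_VAL,
                                 override_backend_env_variable)


def test_example_kernel(monkeypatch: pytest.MonkeyPatch):
    # Force the FLASH_ATTN backend for this test only; monkeypatch restores
    # VLLM_ATTENTION_BACKEND when the test exits.
    override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
    # ... exercise code that reads the backend selection here ...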