[Misc] Remove redundant attention var constants (#29650)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-11-28 20:35:19 +08:00; committed by GitHub
parent 5c2b5cb422
commit 33b06a6f24
7 changed files with 19 additions and 63 deletions

View File

@@ -11,7 +11,6 @@ from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL
@pytest.fixture(autouse=True)
@@ -83,7 +82,7 @@ def test_env(
):
"""Test attention backend selection with valid device-backend pairs."""
with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, name)
m.setenv("VLLM_ATTENTION_BACKEND", name)
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
if device == "cpu":
@@ -237,27 +236,27 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
)
with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)
m.setenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")
# Unsupported CUDA arch
monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
backend = get_attn_backend(16, torch.float16, None, 16)
assert backend.get_name() != STR_FLASH_ATTN_VAL
assert backend.get_name() != "FLASH_ATTN"
# Reset the monkeypatch for subsequent tests
monkeypatch.undo()
# Unsupported data type
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
assert backend.get_name() != STR_FLASH_ATTN_VAL
assert backend.get_name() != "FLASH_ATTN"
# Unsupported kv cache data type
backend = get_attn_backend(16, torch.float16, "fp8", 16)
assert backend.get_name() != STR_FLASH_ATTN_VAL
assert backend.get_name() != "FLASH_ATTN"
# Unsupported block size
backend = get_attn_backend(16, torch.float16, None, 8)
assert backend.get_name() != STR_FLASH_ATTN_VAL
assert backend.get_name() != "FLASH_ATTN"
# flash-attn is not installed
import sys
@@ -265,7 +264,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
original_module = sys.modules.get("vllm_flash_attn")
monkeypatch.setitem(sys.modules, "vllm_flash_attn", None)
backend = get_attn_backend(16, torch.float16, None, 16)
assert backend.get_name() != STR_FLASH_ATTN_VAL
assert backend.get_name() != "FLASH_ATTN"
# Restore the original module if it existed
if original_module is not None:
@@ -275,7 +274,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
# Unsupported head size
backend = get_attn_backend(17, torch.float16, None, 16)
assert backend.get_name() != STR_FLASH_ATTN_VAL
assert backend.get_name() != "FLASH_ATTN"
def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
@@ -284,7 +283,7 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
monkeypatch.context() as m,
patch("vllm.platforms.current_platform", CudaPlatform()),
):
m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
m.setenv("VLLM_ATTENTION_BACKEND", "INVALID")
# Should raise ValueError for invalid backend
with pytest.raises(ValueError) as exc_info:
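
Taken together, the tests in this file now force a backend by writing the environment variable and backend name as literal strings instead of importing STR_BACKEND_ENV_VAR / STR_FLASH_ATTN_VAL / STR_INVALID_VAL from vllm.utils. A minimal sketch of the new idiom, assuming a CUDA machine where FlashAttention is actually supported (the test name is hypothetical; the argument values mirror the calls above):

import pytest
import torch

from vllm.attention.selector import get_attn_backend


def test_flash_attn_override_sketch(monkeypatch: pytest.MonkeyPatch) -> None:
    with monkeypatch.context() as m:
        # Literal env var name replaces the removed STR_BACKEND_ENV_VAR constant.
        m.setenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")
        # head_size=16, dtype=float16, kv_cache_dtype=None, block_size=16,
        # matching the supported-configuration calls in the tests above.
        backend = get_attn_backend(16, torch.float16, None, 16)
        # Literal backend name replaces the removed STR_FLASH_ATTN_VAL constant.
        assert backend.get_name() == "FLASH_ATTN"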

View File

@@ -6,7 +6,6 @@ import torch
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.platforms.rocm import RocmPlatform
from vllm.utils import STR_BACKEND_ENV_VAR
@pytest.fixture(autouse=True)
@@ -18,7 +17,7 @@ def clear_cache():
@pytest.mark.skip(reason="Skipped for now. Should be revisited.")
def test_selector(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_ATTN")
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_ATTN")
# Set the current platform to ROCm using monkeypatch
monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
@@ -30,19 +29,19 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# MLA test for deepseek related
# change the attention backend to triton MLA
m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_MLA")
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
assert backend.get_name() == "TRITON_MLA"
# If attention backend is None
# If use_mla is true
# The selected backend is triton MLA
m.setenv(STR_BACKEND_ENV_VAR, None)
m.setenv("VLLM_ATTENTION_BACKEND", "")
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
assert backend.get_name() == "TRITON_MLA"
# change the attention backend to AITER MLA
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_MLA")
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
assert backend.get_name() == "ROCM_AITER_MLA"
@@ -50,7 +49,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# If use_mla is true
# If VLLM_ROCM_USE_AITER is enabled
# The selected backend is ROCM_AITER_MLA
m.setenv(STR_BACKEND_ENV_VAR, None)
m.setenv("VLLM_ATTENTION_BACKEND", "")
m.setenv("VLLM_ROCM_USE_AITER", "1")
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
assert backend.get_name() == "ROCM_AITER_MLA"

View File

@@ -9,7 +9,6 @@ from numbers import Number
from typing import Any, NamedTuple
from unittest.mock import patch
import pytest
import torch
from torch._prims_common import TensorLikeType
@@ -17,9 +16,6 @@ from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.attention.backends.abstract import AttentionType
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils import (
STR_BACKEND_ENV_VAR,
)
from vllm.utils.torch_utils import make_tensor_with_pad
# For now, disable "test_aot_dispatch_dynamic" since there are some
@@ -217,22 +213,6 @@ def make_causal_mask(
return mask
def override_backend_env_variable(
mpatch: pytest.MonkeyPatch, backend_name: str
) -> None:
"""
Override the environment variable indicating the vLLM backend temporarily,
using pytest monkeypatch to ensure that the env vars get
reset once the test context exits.
Arguments:
* mpatch: pytest monkeypatch instance
* backend_name: attention backend name to force
"""
mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name)
def ref_masked_attention(
query: torch.Tensor,
key: torch.Tensor,
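
The removed override_backend_env_variable helper was a thin wrapper around monkeypatch.setenv, so callers migrate by setting the variable directly under its literal name. A small sketch of the replacement pattern (the test name and "ROCM_ATTN" value are illustrative only):

import pytest


def test_with_forced_backend(monkeypatch: pytest.MonkeyPatch) -> None:
    # Before: override_backend_env_variable(monkeypatch, "ROCM_ATTN")
    # After:  set the env var directly; monkeypatch restores it on teardown.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "ROCM_ATTN")
    ...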

View File

@@ -11,7 +11,6 @@ import pytest
from tests.quantization.utils import is_quant_method_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR
from ..utils import check_logprobs_close
@@ -76,7 +75,7 @@ def test_models(
with monkeypatch.context() as m:
m.setenv("TOKENIZERS_PARALLELISM", "true")
m.setenv(STR_BACKEND_ENV_VAR, backend)
m.setenv("VLLM_ATTENTION_BACKEND", backend)
MAX_MODEL_LEN = 1024
NUM_LOG_PROBS = 8

View File

@@ -19,7 +19,6 @@ from vllm.attention.backends.registry import (
)
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.utils import STR_BACKEND_ENV_VAR
from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
@@ -35,7 +34,7 @@ def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
* AttentionBackendEnum value if an override is specified
* None otherwise
"""
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
backend_name = os.environ.get("VLLM_ATTENTION_BACKEND")
if backend_name is None:
return None
if backend_name == "XFORMERS":
@@ -139,10 +138,10 @@ def _cached_get_attn_backend(
if backend_by_env_var.endswith("_VLLM_V1"):
logger.warning(
"The suffix '_VLLM_V1' in the environment variable "
"%s is no longer necessary as V0 backends have been "
"deprecated. Please remove this suffix from your "
"VLLM_ATTENTION_BACKEND is no longer necessary as "
"V0 backends have been deprecated. "
"Please remove this suffix from your "
"environment variable setting.",
STR_BACKEND_ENV_VAR,
)
backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")
try:
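
For reference, the override resolution in selector.py now works purely with the literal variable name. A condensed sketch of just the part touched by this diff (hypothetical helper name; the real code also logs a warning about the deprecated suffix and special-cases some backend names such as XFORMERS):

import os


def resolve_backend_override() -> str | None:
    name = os.environ.get("VLLM_ATTENTION_BACKEND")
    if name is None:
        return None
    # Legacy V0 names carried a "_VLLM_V1" suffix; it is stripped and ignored.
    return name.removesuffix("_VLLM_V1")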

View File

@@ -23,12 +23,9 @@ from vllm.model_executor.models.deepseek_v2 import (
DeepseekV2DecoderLayer,
DeepseekV3ForCausalLM,
)
from vllm.utils import init_logger
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
logger = init_logger(__name__)
@support_torch_compile
class DeepseekV2Model(nn.Module):

View File

@@ -7,8 +7,6 @@ from typing import Any
import torch
from vllm.logger import init_logger
_DEPRECATED_MAPPINGS = {
"cprofile": "profiling",
"cprofile_context": "profiling",
@@ -37,21 +35,6 @@ def __dir__() -> list[str]:
return sorted(list(globals().keys()) + list(_DEPRECATED_MAPPINGS.keys()))
logger = init_logger(__name__)
# Constants related to forcing the attention backend selection
# String name of register which may be set in order to
# force auto-selection of attention backend by Attention
# wrapper
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
MASK_64_BITS = (1 << 64) - 1
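
Since STR_BACKEND_ENV_VAR, STR_FLASHINFER_ATTN_VAL, STR_FLASH_ATTN_VAL, and STR_INVALID_VAL are gone from vllm.utils, any out-of-tree code that imported them should switch to the literal strings. A minimal migration sketch:

import os

# Before (no longer importable after this commit):
#     from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL
#     os.environ[STR_BACKEND_ENV_VAR] = STR_FLASH_ATTN_VAL
# After: use the literal strings directly.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"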