mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 07:24:54 +08:00
[Misc] Remove redundant attention var constants (#29650)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
5c2b5cb422
commit
33b06a6f24
@ -11,7 +11,6 @@ from vllm.platforms import current_platform
|
||||
from vllm.platforms.cpu import CpuPlatform
|
||||
from vllm.platforms.cuda import CudaPlatform
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@ -83,7 +82,7 @@ def test_env(
|
||||
):
|
||||
"""Test attention backend selection with valid device-backend pairs."""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv(STR_BACKEND_ENV_VAR, name)
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", name)
|
||||
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
|
||||
|
||||
if device == "cpu":
|
||||
@ -237,27 +236,27 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")
|
||||
|
||||
# Unsupported CUDA arch
|
||||
monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
|
||||
backend = get_attn_backend(16, torch.float16, None, 16)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
assert backend.get_name() != "FLASH_ATTN"
|
||||
|
||||
# Reset the monkeypatch for subsequent tests
|
||||
monkeypatch.undo()
|
||||
|
||||
# Unsupported data type
|
||||
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
assert backend.get_name() != "FLASH_ATTN"
|
||||
|
||||
# Unsupported kv cache data type
|
||||
backend = get_attn_backend(16, torch.float16, "fp8", 16)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
assert backend.get_name() != "FLASH_ATTN"
|
||||
|
||||
# Unsupported block size
|
||||
backend = get_attn_backend(16, torch.float16, None, 8)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
assert backend.get_name() != "FLASH_ATTN"
|
||||
|
||||
# flash-attn is not installed
|
||||
import sys
|
||||
@ -265,7 +264,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
original_module = sys.modules.get("vllm_flash_attn")
|
||||
monkeypatch.setitem(sys.modules, "vllm_flash_attn", None)
|
||||
backend = get_attn_backend(16, torch.float16, None, 16)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
assert backend.get_name() != "FLASH_ATTN"
|
||||
|
||||
# Restore the original module if it existed
|
||||
if original_module is not None:
|
||||
@ -275,7 +274,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
# Unsupported head size
|
||||
backend = get_attn_backend(17, torch.float16, None, 16)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
assert backend.get_name() != "FLASH_ATTN"
|
||||
|
||||
|
||||
def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
|
||||
@ -284,7 +283,7 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.context() as m,
|
||||
patch("vllm.platforms.current_platform", CudaPlatform()),
|
||||
):
|
||||
m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "INVALID")
|
||||
|
||||
# Should raise ValueError for invalid backend
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
|
||||
@ -6,7 +6,6 @@ import torch
|
||||
|
||||
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
from vllm.utils import STR_BACKEND_ENV_VAR
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@ -18,7 +17,7 @@ def clear_cache():
|
||||
@pytest.mark.skip(reason="Skipped for now. Should be revisited.")
|
||||
def test_selector(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_ATTN")
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_ATTN")
|
||||
|
||||
# Set the current platform to ROCm using monkeypatch
|
||||
monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
|
||||
@ -30,19 +29,19 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
|
||||
# MLA test for deepseek related
|
||||
|
||||
# change the attention backend to triton MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_MLA")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
|
||||
assert backend.get_name() == "TRITON_MLA"
|
||||
|
||||
# If attention backend is None
|
||||
# If use_mla is true
|
||||
# The selected backend is triton MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, None)
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
|
||||
assert backend.get_name() == "TRITON_MLA"
|
||||
|
||||
# change the attention backend to AITER MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_MLA")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
|
||||
assert backend.get_name() == "ROCM_AITER_MLA"
|
||||
|
||||
@ -50,7 +49,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
|
||||
# If use_mla is true
|
||||
# If VLLM_ROCM_USE_AITER is enabled
|
||||
# The selected backend is ROCM_AITER_MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, None)
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "")
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
|
||||
assert backend.get_name() == "ROCM_AITER_MLA"
|
||||
|
||||
@ -9,7 +9,6 @@ from numbers import Number
|
||||
from typing import Any, NamedTuple
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch._prims_common import TensorLikeType
|
||||
|
||||
@ -17,9 +16,6 @@ from tests.kernels.quant_utils import native_w8a8_block_matmul
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils import (
|
||||
STR_BACKEND_ENV_VAR,
|
||||
)
|
||||
from vllm.utils.torch_utils import make_tensor_with_pad
|
||||
|
||||
# For now, disable "test_aot_dispatch_dynamic" since there are some
|
||||
@ -217,22 +213,6 @@ def make_causal_mask(
|
||||
return mask
|
||||
|
||||
|
||||
def override_backend_env_variable(
|
||||
mpatch: pytest.MonkeyPatch, backend_name: str
|
||||
) -> None:
|
||||
"""
|
||||
Override the environment variable indicating the vLLM backend temporarily,
|
||||
using pytest monkeypatch to ensure that the env vars get
|
||||
reset once the test context exits.
|
||||
|
||||
Arguments:
|
||||
|
||||
* mpatch: pytest monkeypatch instance
|
||||
* backend_name: attention backend name to force
|
||||
"""
|
||||
mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name)
|
||||
|
||||
|
||||
def ref_masked_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
|
||||
@ -11,7 +11,6 @@ import pytest
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import STR_BACKEND_ENV_VAR
|
||||
from ..utils import check_logprobs_close
|
||||
|
||||
|
||||
@ -76,7 +75,7 @@ def test_models(
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("TOKENIZERS_PARALLELISM", "true")
|
||||
m.setenv(STR_BACKEND_ENV_VAR, backend)
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||
|
||||
MAX_MODEL_LEN = 1024
|
||||
NUM_LOG_PROBS = 8
|
||||
|
||||
@ -19,7 +19,6 @@ from vllm.attention.backends.registry import (
|
||||
)
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import STR_BACKEND_ENV_VAR
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -35,7 +34,7 @@ def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
|
||||
* AttentionBackendEnum value if an override is specified
|
||||
* None otherwise
|
||||
"""
|
||||
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
|
||||
backend_name = os.environ.get("VLLM_ATTENTION_BACKEND")
|
||||
if backend_name is None:
|
||||
return None
|
||||
if backend_name == "XFORMERS":
|
||||
@ -139,10 +138,10 @@ def _cached_get_attn_backend(
|
||||
if backend_by_env_var.endswith("_VLLM_V1"):
|
||||
logger.warning(
|
||||
"The suffix '_VLLM_V1' in the environment variable "
|
||||
"%s is no longer necessary as V0 backends have been "
|
||||
"deprecated. Please remove this suffix from your "
|
||||
"VLLM_ATTENTION_BACKEND is no longer necessary as "
|
||||
"V0 backends have been deprecated. "
|
||||
"Please remove this suffix from your "
|
||||
"environment variable setting.",
|
||||
STR_BACKEND_ENV_VAR,
|
||||
)
|
||||
backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")
|
||||
try:
|
||||
|
||||
@ -23,12 +23,9 @@ from vllm.model_executor.models.deepseek_v2 import (
|
||||
DeepseekV2DecoderLayer,
|
||||
DeepseekV3ForCausalLM,
|
||||
)
|
||||
from vllm.utils import init_logger
|
||||
|
||||
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class DeepseekV2Model(nn.Module):
|
||||
|
||||
@ -7,8 +7,6 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
_DEPRECATED_MAPPINGS = {
|
||||
"cprofile": "profiling",
|
||||
"cprofile_context": "profiling",
|
||||
@ -37,21 +35,6 @@ def __dir__() -> list[str]:
|
||||
return sorted(list(globals().keys()) + list(_DEPRECATED_MAPPINGS.keys()))
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Constants related to forcing the attention backend selection
|
||||
|
||||
# String name of register which may be set in order to
|
||||
# force auto-selection of attention backend by Attention
|
||||
# wrapper
|
||||
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
|
||||
|
||||
# Possible string values of STR_BACKEND_ENV_VAR
|
||||
# register, corresponding to possible backends
|
||||
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
|
||||
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
|
||||
STR_INVALID_VAL: str = "INVALID"
|
||||
|
||||
MASK_64_BITS = (1 << 64) - 1
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user