[Misc] Remove redundant attention var constants (#29650)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent 5c2b5cb422
commit 33b06a6f24
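The change drops the indirection through constants in vllm.utils and has callers use the literal environment-variable name and backend strings directly. A minimal sketch of the resulting test pattern, assuming a CUDA machine where the FLASH_ATTN backend is actually selectable (the test name is illustrative; the get_attn_backend call mirrors the ones in the diff below):

import pytest
import torch

from vllm.attention.selector import get_attn_backend


def test_backend_env_override(monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m:
        # Literal strings replace vllm.utils.STR_BACKEND_ENV_VAR and
        # STR_FLASH_ATTN_VAL.
        m.setenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")
        backend = get_attn_backend(16, torch.float16, None, 16)
        assert backend.get_name() == "FLASH_ATTN"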
@@ -11,7 +11,6 @@ from vllm.platforms import current_platform
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
 from vllm.platforms.rocm import RocmPlatform
-from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL
 
 
 @pytest.fixture(autouse=True)
@@ -83,7 +82,7 @@ def test_env(
 ):
     """Test attention backend selection with valid device-backend pairs."""
     with monkeypatch.context() as m:
-        m.setenv(STR_BACKEND_ENV_VAR, name)
+        m.setenv("VLLM_ATTENTION_BACKEND", name)
         m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
 
         if device == "cpu":
@@ -237,27 +236,27 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
     )
 
     with monkeypatch.context() as m:
-        m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)
+        m.setenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")
 
         # Unsupported CUDA arch
         monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
         backend = get_attn_backend(16, torch.float16, None, 16)
-        assert backend.get_name() != STR_FLASH_ATTN_VAL
+        assert backend.get_name() != "FLASH_ATTN"
 
         # Reset the monkeypatch for subsequent tests
         monkeypatch.undo()
 
         # Unsupported data type
         backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
-        assert backend.get_name() != STR_FLASH_ATTN_VAL
+        assert backend.get_name() != "FLASH_ATTN"
 
         # Unsupported kv cache data type
         backend = get_attn_backend(16, torch.float16, "fp8", 16)
-        assert backend.get_name() != STR_FLASH_ATTN_VAL
+        assert backend.get_name() != "FLASH_ATTN"
 
         # Unsupported block size
         backend = get_attn_backend(16, torch.float16, None, 8)
-        assert backend.get_name() != STR_FLASH_ATTN_VAL
+        assert backend.get_name() != "FLASH_ATTN"
 
         # flash-attn is not installed
         import sys
@@ -265,7 +264,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
         original_module = sys.modules.get("vllm_flash_attn")
         monkeypatch.setitem(sys.modules, "vllm_flash_attn", None)
         backend = get_attn_backend(16, torch.float16, None, 16)
-        assert backend.get_name() != STR_FLASH_ATTN_VAL
+        assert backend.get_name() != "FLASH_ATTN"
 
         # Restore the original module if it existed
         if original_module is not None:
@@ -275,7 +274,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
 
         # Unsupported head size
         backend = get_attn_backend(17, torch.float16, None, 16)
-        assert backend.get_name() != STR_FLASH_ATTN_VAL
+        assert backend.get_name() != "FLASH_ATTN"
 
 
 def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
@@ -284,7 +283,7 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
         monkeypatch.context() as m,
         patch("vllm.platforms.current_platform", CudaPlatform()),
     ):
-        m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
+        m.setenv("VLLM_ATTENTION_BACKEND", "INVALID")
 
         # Should raise ValueError for invalid backend
         with pytest.raises(ValueError) as exc_info:

@@ -6,7 +6,6 @@ import torch
 
 from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
 from vllm.platforms.rocm import RocmPlatform
-from vllm.utils import STR_BACKEND_ENV_VAR
 
 
 @pytest.fixture(autouse=True)
@@ -18,7 +17,7 @@ def clear_cache():
 @pytest.mark.skip(reason="Skipped for now. Should be revisited.")
 def test_selector(monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m:
-        m.setenv(STR_BACKEND_ENV_VAR, "ROCM_ATTN")
+        m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_ATTN")
 
         # Set the current platform to ROCm using monkeypatch
         monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
@@ -30,19 +29,19 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
         # MLA test for deepseek related
 
         # change the attention backend to triton MLA
-        m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
+        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_MLA")
         backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
         assert backend.get_name() == "TRITON_MLA"
 
         # If attention backend is None
         # If use_mla is true
         # The selected backend is triton MLA
-        m.setenv(STR_BACKEND_ENV_VAR, None)
+        m.setenv("VLLM_ATTENTION_BACKEND", "")
         backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
         assert backend.get_name() == "TRITON_MLA"
 
         # change the attention backend to AITER MLA
-        m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
+        m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_MLA")
         backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
         assert backend.get_name() == "ROCM_AITER_MLA"
 
@@ -50,7 +49,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
         # If use_mla is true
         # If VLLM_ROCM_USE_AITER is enabled
        # The selected backend is ROCM_AITER_MLA
-        m.setenv(STR_BACKEND_ENV_VAR, None)
+        m.setenv("VLLM_ATTENTION_BACKEND", "")
         m.setenv("VLLM_ROCM_USE_AITER", "1")
         backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
         assert backend.get_name() == "ROCM_AITER_MLA"

@@ -9,7 +9,6 @@ from numbers import Number
 from typing import Any, NamedTuple
 from unittest.mock import patch
 
-import pytest
 import torch
 from torch._prims_common import TensorLikeType
 
@@ -17,9 +16,6 @@ from tests.kernels.quant_utils import native_w8a8_block_matmul
 from vllm.attention.backends.abstract import AttentionType
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
-from vllm.utils import (
-    STR_BACKEND_ENV_VAR,
-)
 from vllm.utils.torch_utils import make_tensor_with_pad
 
 # For now, disable "test_aot_dispatch_dynamic" since there are some
@@ -217,22 +213,6 @@ def make_causal_mask(
     return mask
 
 
-def override_backend_env_variable(
-    mpatch: pytest.MonkeyPatch, backend_name: str
-) -> None:
-    """
-    Override the environment variable indicating the vLLM backend temporarily,
-    using pytest monkeypatch to ensure that the env vars get
-    reset once the test context exits.
-
-    Arguments:
-
-    * mpatch: pytest monkeypatch instance
-    * backend_name: attention backend name to force
-    """
-    mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name)
-
-
 def ref_masked_attention(
     query: torch.Tensor,
     key: torch.Tensor,

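Tests that previously called the removed override_backend_env_variable helper can set the variable through monkeypatch directly; a minimal drop-in sketch (the helper name below is hypothetical, not part of vLLM):

import pytest


def force_attention_backend(
    monkeypatch: pytest.MonkeyPatch, backend_name: str
) -> None:
    # Same effect as the deleted helper: monkeypatch restores the previous
    # value of the variable once the test or context exits.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend_name)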
@@ -11,7 +11,6 @@ import pytest
 from tests.quantization.utils import is_quant_method_supported
 from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
 from vllm.platforms import current_platform
-from vllm.utils import STR_BACKEND_ENV_VAR
 from ..utils import check_logprobs_close
 
 
@@ -76,7 +75,7 @@ def test_models(
 
     with monkeypatch.context() as m:
         m.setenv("TOKENIZERS_PARALLELISM", "true")
-        m.setenv(STR_BACKEND_ENV_VAR, backend)
+        m.setenv("VLLM_ATTENTION_BACKEND", backend)
 
         MAX_MODEL_LEN = 1024
         NUM_LOG_PROBS = 8

@@ -19,7 +19,6 @@ from vllm.attention.backends.registry import (
 )
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
-from vllm.utils import STR_BACKEND_ENV_VAR
 from vllm.utils.import_utils import resolve_obj_by_qualname
 
 logger = init_logger(__name__)
@@ -35,7 +34,7 @@ def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
     * AttentionBackendEnum value if an override is specified
     * None otherwise
     """
-    backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
+    backend_name = os.environ.get("VLLM_ATTENTION_BACKEND")
     if backend_name is None:
         return None
     if backend_name == "XFORMERS":
@@ -139,10 +138,10 @@ def _cached_get_attn_backend(
         if backend_by_env_var.endswith("_VLLM_V1"):
             logger.warning(
                 "The suffix '_VLLM_V1' in the environment variable "
-                "%s is no longer necessary as V0 backends have been "
-                "deprecated. Please remove this suffix from your "
+                "VLLM_ATTENTION_BACKEND is no longer necessary as "
+                "V0 backends have been deprecated. "
+                "Please remove this suffix from your "
                 "environment variable setting.",
-                STR_BACKEND_ENV_VAR,
             )
             backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")
         try:

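For orientation, the environment-variable handling in the selector now amounts to roughly the sketch below; this is a simplified approximation of get_env_variable_attn_backend and _cached_get_attn_backend, which additionally map the name onto AttentionBackendEnum and validate it:

import os


def read_backend_override() -> str | None:
    # Read the raw override, if any.
    name = os.environ.get("VLLM_ATTENTION_BACKEND")
    if name is None:
        return None
    # The V0-era "_VLLM_V1" suffix is tolerated but stripped (the real code
    # also logs a deprecation warning) before the name is resolved.
    return name.removesuffix("_VLLM_V1")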
@@ -23,12 +23,9 @@ from vllm.model_executor.models.deepseek_v2 import (
     DeepseekV2DecoderLayer,
     DeepseekV3ForCausalLM,
 )
-from vllm.utils import init_logger
 
 from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
 
-logger = init_logger(__name__)
-
 
 @support_torch_compile
 class DeepseekV2Model(nn.Module):

@@ -7,8 +7,6 @@ from typing import Any
 
 import torch
 
-from vllm.logger import init_logger
-
 _DEPRECATED_MAPPINGS = {
     "cprofile": "profiling",
     "cprofile_context": "profiling",
@@ -37,21 +35,6 @@ def __dir__() -> list[str]:
     return sorted(list(globals().keys()) + list(_DEPRECATED_MAPPINGS.keys()))
 
 
-logger = init_logger(__name__)
-
-# Constants related to forcing the attention backend selection
-
-# String name of register which may be set in order to
-# force auto-selection of attention backend by Attention
-# wrapper
-STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
-
-# Possible string values of STR_BACKEND_ENV_VAR
-# register, corresponding to possible backends
-STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
-STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
-STR_INVALID_VAL: str = "INVALID"
-
 MASK_64_BITS = (1 << 64) - 1
 
 
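Out-of-tree code that imported these constants from vllm.utils now needs the literal strings; a summary of the mapping, with an illustrative override at the end (values taken from the deleted lines above):

import os

# STR_BACKEND_ENV_VAR      -> "VLLM_ATTENTION_BACKEND"
# STR_FLASHINFER_ATTN_VAL  -> "FLASHINFER"
# STR_FLASH_ATTN_VAL       -> "FLASH_ATTN"
# STR_INVALID_VAL          -> "INVALID"

# Illustrative only: force the FlashInfer backend for this process.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"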