mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 07:24:25 +08:00
[Bugfix] Move current_platform import to avoid python import cache. (#16601)
Signed-off-by: iwzbi <wzbi@zju.edu.cn>
This commit is contained in:
parent
0426e3c5e1
commit
ec10fd0abc
@ -84,12 +84,12 @@ def test_env(
|
|||||||
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
|
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
|
||||||
|
|
||||||
if device == "cpu":
|
if device == "cpu":
|
||||||
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
|
with patch("vllm.platforms.current_platform", CpuPlatform()):
|
||||||
backend = get_attn_backend(16, torch.float16, None, block_size)
|
backend = get_attn_backend(16, torch.float16, None, block_size)
|
||||||
assert backend.get_name() == "TORCH_SDPA"
|
assert backend.get_name() == "TORCH_SDPA"
|
||||||
|
|
||||||
elif device == "hip":
|
elif device == "hip":
|
||||||
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
|
with patch("vllm.platforms.current_platform", RocmPlatform()):
|
||||||
if use_mla:
|
if use_mla:
|
||||||
# ROCm MLA backend logic:
|
# ROCm MLA backend logic:
|
||||||
# - TRITON_MLA: supported when block_size != 1
|
# - TRITON_MLA: supported when block_size != 1
|
||||||
@ -126,7 +126,7 @@ def test_env(
|
|||||||
assert backend.get_name() == expected
|
assert backend.get_name() == expected
|
||||||
|
|
||||||
elif device == "cuda":
|
elif device == "cuda":
|
||||||
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
|
with patch("vllm.platforms.current_platform", CudaPlatform()):
|
||||||
if use_mla:
|
if use_mla:
|
||||||
# CUDA MLA backend logic:
|
# CUDA MLA backend logic:
|
||||||
# - CUTLASS_MLA: only supported with block_size == 128
|
# - CUTLASS_MLA: only supported with block_size == 128
|
||||||
@ -214,12 +214,12 @@ def test_env(
|
|||||||
def test_fp32_fallback(device: str):
|
def test_fp32_fallback(device: str):
|
||||||
"""Test attention backend selection with fp32."""
|
"""Test attention backend selection with fp32."""
|
||||||
if device == "cpu":
|
if device == "cpu":
|
||||||
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
|
with patch("vllm.platforms.current_platform", CpuPlatform()):
|
||||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||||
assert backend.get_name() == "TORCH_SDPA"
|
assert backend.get_name() == "TORCH_SDPA"
|
||||||
|
|
||||||
elif device == "cuda":
|
elif device == "cuda":
|
||||||
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
|
with patch("vllm.platforms.current_platform", CudaPlatform()):
|
||||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||||
assert backend.get_name() == "FLEX_ATTENTION"
|
assert backend.get_name() == "FLEX_ATTENTION"
|
||||||
|
|
||||||
@ -277,7 +277,7 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
|
|||||||
"""Test that invalid attention backend names raise ValueError."""
|
"""Test that invalid attention backend names raise ValueError."""
|
||||||
with (
|
with (
|
||||||
monkeypatch.context() as m,
|
monkeypatch.context() as m,
|
||||||
patch("vllm.attention.selector.current_platform", CudaPlatform()),
|
patch("vllm.platforms.current_platform", CudaPlatform()),
|
||||||
):
|
):
|
||||||
m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
|
m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
|
||||||
|
|
||||||
|
|||||||
@ -14,7 +14,6 @@ import vllm.envs as envs
|
|||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
|
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
|
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -192,6 +191,8 @@ def _cached_get_attn_backend(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# get device-specific attn_backend
|
# get device-specific attn_backend
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
attention_cls = current_platform.get_attn_backend_cls(
|
attention_cls = current_platform.get_attn_backend_cls(
|
||||||
selected_backend,
|
selected_backend,
|
||||||
head_size,
|
head_size,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user