Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-19 09:15:49 +08:00)
[V0 deprecation] Remove _VLLM_V1 suffixes from attention backend names (#25489)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent 9659b7e78f
commit a355561291
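The rename is mechanical: every attention backend name drops its _VLLM_V1 suffix now that the V0 code paths are gone. A compact, illustrative summary of the renames that appear in the diff below (the diff itself is authoritative), written as a Python mapping:

# Old V0-suffixed spelling -> new canonical spelling, as seen in this diff.
RENAMED_BACKENDS = {
    "FLASH_ATTN_VLLM_V1": "FLASH_ATTN",
    "TRITON_ATTN_VLLM_V1": "TRITON_ATTN",
    "FLASHINFER_VLLM_V1": "FLASHINFER",
    "TORCH_SDPA_VLLM_V1": "TORCH_SDPA",
    "PALLAS_VLLM_V1": "PALLAS",
    "XFORMERS_VLLM_V1": "XFORMERS",
    "FLASHMLA_VLLM_V1": "FLASHMLA",
    "TRITON_MLA_VLLM_V1": "TRITON_MLA",
    "CUTLASS_MLA_VLLM_V1": "CUTLASS_MLA",
    "ROCM_AITER_MLA_VLLM_V1": "ROCM_AITER_MLA",
    "ROCM_ATTN_VLLM_V1": "ROCM_ATTN",
    "TREE_ATTN_VLLM_V1": "TREE_ATTN",
}

Anywhere VLLM_ATTENTION_BACKEND was set to a suffixed value, the plain name is now the canonical spelling; the selector change later in this diff strips the suffix and logs a deprecation warning, so old settings keep working for the time being.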
@@ -35,7 +35,7 @@ docker run \
 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
-VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
+VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
 cd tests
 pytest -v -s v1/core
 pytest -v -s v1/engine
@@ -103,7 +103,7 @@ backend_configs = {
     # Triton Attention
     "TritonAttn":
         BackendConfig(name="TritonAttn",
-                      env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
+                      env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
                       comp_config={
                           "cudagraph_mode": "FULL",
                       }),
@@ -338,7 +338,7 @@ else:
 @pytest.mark.parametrize("model_name, model_class", MODELS)
 @pytest.mark.parametrize("backend",
                          [_Backend.FLASHINFER] if current_platform.is_cuda()
-                         else [_Backend.TRITON_ATTN_VLLM_V1])
+                         else [_Backend.TRITON_ATTN])
 @pytest.mark.parametrize(
     "split_attention",
     [False, True] if current_platform.is_rocm() else [False])
@@ -68,7 +68,7 @@ def default_server_args(with_tool_parser: bool):
 def gptoss_server(monkeypatch_module: pytest.MonkeyPatch,
                   default_server_args: list[str]):
     with monkeypatch_module.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
+        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
         with RemoteOpenAIServer(GPT_OSS_MODEL_NAME,
                                 default_server_args) as remote_server:
             yield remote_server
@@ -31,7 +31,7 @@ DEVICE_MLA_BACKENDS = {
 }

 DEVICE_REGULAR_ATTN_BACKENDS = {
-    "cuda": ["XFORMERS", "FLASHINFER"],
+    "cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
     "hip": ["ROCM_FLASH"],
     "cpu": ["TORCH_SDPA"],
 }
@@ -86,7 +86,7 @@ def test_env(
         with patch("vllm.attention.selector.current_platform",
                    CpuPlatform()):
             backend = get_attn_backend(16, torch.float16, None, block_size)
-            assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
+            assert backend.get_name() == "TORCH_SDPA"

     elif device == "hip":
         with patch("vllm.attention.selector.current_platform",
@@ -125,7 +125,7 @@ def test_env(
                                            None,
                                            block_size,
                                            use_mla=use_mla)
-                expected = f"{name}_VLLM_V1"
+                expected = name
                 assert backend.get_name() == expected
             else:
                 backend = get_attn_backend(16,
@@ -133,7 +133,7 @@ def test_env(
                                            None,
                                            block_size,
                                            use_mla=use_mla)
-                expected = "TRITON_ATTN_VLLM_V1"
+                expected = "TRITON_ATTN"
                 assert backend.get_name() == expected

     elif device == "cuda":
@@ -160,7 +160,7 @@ def test_env(
                                                None,
                                                block_size,
                                                use_mla=use_mla)
-                    expected = "CUTLASS_MLA_VLLM_V1"
+                    expected = "CUTLASS_MLA"
                     assert backend.get_name() == expected
                 elif name == "FLASHINFER_MLA":
                     if block_size not in [32, 64]:
@@ -193,7 +193,7 @@ def test_env(
                                                None,
                                                block_size,
                                                use_mla=use_mla)
-                    expected = f"{name}_VLLM_V1"
+                    expected = name
                     assert backend.get_name() == expected
                 elif name == "FLASH_ATTN_MLA":
                     backend = get_attn_backend(16,
@@ -210,7 +210,7 @@ def test_env(
                                                None,
                                                block_size,
                                                use_mla=use_mla)
-                    expected = "TRITON_MLA_VLLM_V1"
+                    expected = "TRITON_MLA"
                     assert backend.get_name() == expected
                 elif name == "FLASHINFER":
                     backend = get_attn_backend(16,
@@ -218,25 +218,24 @@ def test_env(
                                                None,
                                                block_size,
                                                use_mla=use_mla)
-                    expected = "FLASHINFER_VLLM_V1"
+                    expected = "FLASHINFER"
                     assert backend.get_name() == expected
-                else:
+                elif name == "XFORMERS":
                     backend = get_attn_backend(32,
                                                torch.float16,
                                                None,
                                                block_size,
                                                use_mla=use_mla)
-                    expected = "FLASH_ATTN_VLLM_V1"
+                    expected = "XFORMERS"
                     assert backend.get_name() == expected
-
-                    backend = get_attn_backend(16,
+                elif name == "FLASH_ATTN":
+                    backend = get_attn_backend(32,
                                                torch.float16,
                                                None,
                                                block_size,
                                                use_mla=use_mla)
-                    assert backend.get_name() == "FLEX_ATTENTION", (
-                        "Should fallback to FlexAttention if head size is "
-                        "not supported by FlashAttention")
+                    expected = "FLASH_ATTN"
+                    assert backend.get_name() == expected


 @pytest.mark.parametrize("device", ["cpu", "cuda"])
@@ -252,7 +251,7 @@ def test_fp32_fallback(
         with patch("vllm.attention.selector.current_platform",
                    CpuPlatform()):
             backend = get_attn_backend(16, torch.float32, None, 16)
-            assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
+            assert backend.get_name() == "TORCH_SDPA"

     elif device == "cuda":
         with patch("vllm.attention.selector.current_platform",
@@ -266,6 +265,9 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
     # get_attn_backend

+    pytest.skip("Skipping as current backend selector does not " \
+        "handle fallbacks when a backend is set via env var.")
+
     with monkeypatch.context() as m:
         m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)

@@ -28,7 +28,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
         # Test standard ROCm attention
         backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
         assert (backend.get_name() == "ROCM_FLASH"
-                or backend.get_name() == "TRITON_ATTN_VLLM_V1")
+                or backend.get_name() == "TRITON_ATTN")

         # MLA test for deepseek related

@@ -40,8 +40,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
                                    16,
                                    False,
                                    use_mla=True)
-        assert (backend.get_name() == "TRITON_MLA"
-                or backend.get_name() == "TRITON_MLA_VLLM_V1")
+        assert backend.get_name() == "TRITON_MLA"

         # If attention backend is None
         # If use_mla is true
@@ -53,8 +52,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
                                    16,
                                    False,
                                    use_mla=True)
-        assert (backend.get_name() == "TRITON_MLA"
-                or backend.get_name() == "TRITON_MLA_VLLM_V1")
+        assert backend.get_name() == "TRITON_MLA"

         # change the attention backend to AITER MLA
         m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
@@ -64,8 +62,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
                                    1,
                                    False,
                                    use_mla=True)
-        assert (backend.get_name() == "ROCM_AITER_MLA"
-                or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
+        assert backend.get_name() == "ROCM_AITER_MLA"

         # If attention backend is None
         # If use_mla is true
@@ -79,5 +76,4 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
                                    1,
                                    False,
                                    use_mla=True)
-        assert (backend.get_name() == "ROCM_AITER_MLA"
-                or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
+        assert backend.get_name() == "ROCM_AITER_MLA"
@@ -524,14 +524,14 @@ def make_backend(backend_name: str) -> AttentionBackend:

     * Backend instance
     '''
-    if backend_name in (STR_XFORMERS_ATTN_VAL, "XFORMERS_VLLM_V1"):
+    if backend_name == STR_XFORMERS_ATTN_VAL:
        from vllm.v1.attention.backends.xformers import (
            XFormersAttentionBackend)
        return XFormersAttentionBackend()
-    if backend_name in (STR_FLASH_ATTN_VAL, "FLASH_ATTN_VLLM_V1"):
+    if backend_name == STR_FLASH_ATTN_VAL:
        from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
        return FlashAttentionBackend()
-    if backend_name == "TRITON_ATTN_VLLM_V1":
+    if backend_name == "TRITON_ATTN":
        from vllm.v1.attention.backends.triton_attn import (
            TritonAttentionBackend)
        return TritonAttentionBackend()
@@ -539,7 +539,7 @@ def make_backend(backend_name: str) -> AttentionBackend:
        from vllm.v1.attention.backends.flex_attention import (
            FlexAttentionBackend)
        return FlexAttentionBackend()
-    if backend_name in ("TORCH_SDPA", "TORCH_SDPA_VLLM_V1"):
+    if backend_name == "TORCH_SDPA":
        from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend
        return TorchSDPABackend()
    if backend_name == "FLASHINFER":
@@ -84,7 +84,7 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
        # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
        # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
        # L4 supports FA3.
-        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
+        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
        if model_arch == "WhisperForConditionalGeneration":
            m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
        LLM(
@@ -1131,14 +1131,14 @@ def has_module_attribute(module_name, attribute_name):

 def get_attn_backend_list_based_on_platform() -> list[str]:
     if current_platform.is_cuda():
-        return ["FLASH_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1", "TREE_ATTN"]
+        return ["FLASH_ATTN", "TRITON_ATTN", "TREE_ATTN"]
     elif current_platform.is_rocm():
-        attn_backend_list = ["TRITON_ATTN_VLLM_V1"]
+        attn_backend_list = ["TRITON_ATTN"]
         try:
             import aiter  # noqa: F401
-            attn_backend_list.append("FLASH_ATTN_VLLM_V1")
+            attn_backend_list.append("FLASH_ATTN")
         except Exception:
-            print("Skip FLASH_ATTN_VLLM_V1 on ROCm as aiter is not installed")
+            print("Skip FLASH_ATTN on ROCm as aiter is not installed")

         return attn_backend_list
     else:
@@ -21,16 +21,15 @@ from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
 from vllm.v1.kv_cache_interface import FullAttentionSpec

 BACKENDS_TO_TEST = [
-    _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1,
-    _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN,
-    "FLEX_ATTENTION_SLOW"
+    _Backend.FLASH_ATTN, _Backend.FLASHINFER, _Backend.FLEX_ATTENTION,
+    _Backend.TRITON_ATTN, _Backend.TREE_ATTN, "FLEX_ATTENTION_SLOW"
 ]

 # Remove flashinfer from the list if it's not available
 try:
     import flashinfer  # noqa: F401
 except ImportError:
-    BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_VLLM_V1)
+    BACKENDS_TO_TEST.remove(_Backend.FLASHINFER)


 def _convert_dtype_to_torch(dtype):
@@ -214,7 +213,7 @@ def run_attention_backend(
     builder_cls, impl_cls = get_attention_backend(actual_backend)

     # Mock flashinfer's get_per_layer_parameters if needed
-    if actual_backend == _Backend.FLASHINFER_VLLM_V1:
+    if actual_backend == _Backend.FLASHINFER:
         import unittest.mock

         from vllm.v1.attention.backends.utils import PerLayerParameters
@@ -434,7 +433,7 @@ def _test_backend_correctness(
         # [num_blocks, 2, block_size, num_kv_heads, head_size]
         # Select the appropriate KV cache format for each backend
         kv_cache_for_backend = kv_cache
-        if backend_name == _Backend.FLASHINFER_VLLM_V1:
+        if backend_name == _Backend.FLASHINFER:
             kv_cache_for_backend = kv_cache.transpose(0, 1)

             # For FlashInfer default to HND layout and
@@ -518,8 +517,8 @@ def test_causal_backend_correctness(batch_spec_name: str, model: str):


 SLIDING_WINDOW_BACKENDS_TO_TEST = [
-    _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLEX_ATTENTION,
-    _Backend.TRITON_ATTN_VLLM_V1, "FLEX_ATTENTION_SLOW"
+    _Backend.FLASH_ATTN, _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN,
+    "FLEX_ATTENTION_SLOW"
 ]


@@ -15,8 +15,8 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec

 BACKENDS_TO_TEST = [
-    _Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, _Backend.FLASH_ATTN_MLA,
-    _Backend.TRITON_MLA_VLLM_V1
+    _Backend.CUTLASS_MLA, _Backend.FLASHMLA, _Backend.FLASH_ATTN_MLA,
+    _Backend.TRITON_MLA
 ]

 # Remove CUTLASS_MLA from the list if not using sm100
@@ -120,30 +120,30 @@ def get_attention_backend(backend_name: _Backend):
         Tuple of (backend_builder_class, backend_impl_class)
     """
     backend_map = {
-        _Backend.FLASH_ATTN_VLLM_V1:
+        _Backend.FLASH_ATTN:
         ("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
          if current_platform.is_cuda() else
          "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
         ),
-        _Backend.FLASHINFER_VLLM_V1:
+        _Backend.FLASHINFER:
         "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
         _Backend.FLEX_ATTENTION:
         "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
-        _Backend.TRITON_ATTN_VLLM_V1:
+        _Backend.TRITON_ATTN:
         "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
        _Backend.TREE_ATTN:
        "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
-        _Backend.XFORMERS_VLLM_V1:
+        _Backend.XFORMERS:
        "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",
        _Backend.CUTLASS_MLA:
        "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",
-        _Backend.FLASHMLA_VLLM_V1:
+        _Backend.FLASHMLA:
        "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
        _Backend.FLASH_ATTN_MLA:
        "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
        _Backend.FLASHINFER_MLA:
        "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",
-        _Backend.TRITON_MLA_VLLM_V1:
+        _Backend.TRITON_MLA:
        "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
     }

@@ -89,7 +89,7 @@ backend_configs = {
     # Triton Attention
     "TritonAttn":
         BackendConfig(name="TritonAttn",
-                      env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
+                      env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
                       comp_config={
                           "cudagraph_mode": "FULL_AND_PIECEWISE",
                       }),
@@ -9,11 +9,14 @@ from ...utils import create_new_process_for_each_test


 @create_new_process_for_each_test()
-@pytest.mark.parametrize("attn_backend",
-                         ["FLASH_ATTN_VLLM_V1", "FLASHINFER_VLLM_V1"])
+@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"])
 def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
     prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"

+    if attn_backend == "FLASHINFER":
+        pytest.skip("This test is failing with FlashInfer backend and "
+                    "needs investigation. See issue #25679.")
+
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
         m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
@@ -176,12 +176,11 @@ def test_eagle_correctness(
         m.setenv("VLLM_MLA_DISABLE", "1")
         m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

-        if (attn_backend == "TRITON_ATTN_VLLM_V1"
-                and not current_platform.is_rocm()):
-            pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+        if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
+            pytest.skip("TRITON_ATTN does not support "
                         "multi-token eagle spec decode on current platform")

-        if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+        if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
             m.setenv("VLLM_ROCM_USE_AITER", "1")

         method, model_name, spec_model_name, tp_size = model_setup
@@ -314,12 +314,11 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,

     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

-    if (attn_backend == "TRITON_ATTN_VLLM_V1"
-            and not current_platform.is_rocm()):
-        pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+    if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
+        pytest.skip("TRITON_ATTN does not support "
                     "multi-token eagle spec decode on current platform")

-    if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+    if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

     # Setup draft model mock
@@ -400,16 +399,15 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):

     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

-    if (attn_backend == "TRITON_ATTN_VLLM_V1"
-            and not current_platform.is_rocm()):
-        pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+    if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
+        pytest.skip("TRITON_ATTN does not support "
                     "multi-token eagle spec decode on current platform")

     if (attn_backend == "TREE_ATTN"):
         pytest.skip("TREE_ATTN is tested separately in test_propose_tree"
                     "because it requires special input mocking.")

-    if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+    if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

     # Use GPU device
@@ -510,12 +508,12 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
                                  device=device)
     sampling_metadata = mock.MagicMock()

-    if attn_backend == "FLASH_ATTN_VLLM_V1":
+    if attn_backend == "FLASH_ATTN":
         attn_metadata_builder_cls, _ = get_attention_backend(
-            _Backend.FLASH_ATTN_VLLM_V1)
-    elif attn_backend == "TRITON_ATTN_VLLM_V1":
+            _Backend.FLASH_ATTN)
+    elif attn_backend == "TRITON_ATTN":
         attn_metadata_builder_cls, _ = get_attention_backend(
-            _Backend.TRITON_ATTN_VLLM_V1)
+            _Backend.TRITON_ATTN)
     elif attn_backend == "TREE_ATTN":
         attn_metadata_builder_cls, _ = get_attention_backend(
             _Backend.TREE_ATTN)
@@ -41,12 +41,11 @@ def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
         m.setenv("VLLM_USE_V1", "1")
         m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

-        if (attn_backend == "TRITON_ATTN_VLLM_V1"
-                and not current_platform.is_rocm()):
-            pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+        if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
+            pytest.skip("TRITON_ATTN does not support "
                         "multi-token eagle spec decode on current platform")

-        if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+        if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
             m.setenv("VLLM_ROCM_USE_AITER", "1")

         llm = LLM(
@@ -278,7 +278,7 @@ def test_tree_attn_correctness() -> None:
             block_table=block_table,
             slot_mapping=branch_slot_mapping,
             seqlen_k=sequence_position + q_len,
-            backend=_Backend.FLASH_ATTN_VLLM_V1,
+            backend=_Backend.FLASH_ATTN,
         ).view(batch_size, -1, num_heads, dim_per_head)

         # Compare the outputs.
@@ -54,26 +54,3 @@ def test_v1_llm_by_default(monkeypatch):
         print(llm.generate("Hello my name is"))
         assert hasattr(llm.llm_engine, "engine_core")
         m.delenv("VLLM_USE_V1")
-
-
-def test_v1_attn_backend(monkeypatch):
-    with monkeypatch.context() as m:
-        if os.getenv("VLLM_USE_V1", None):
-            m.delenv("VLLM_USE_V1")
-        m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
-
-        # Fall back to V0.
-        _ = AsyncEngineArgs(model=MODEL).create_engine_config()
-        assert not envs.VLLM_USE_V1
-        m.delenv("VLLM_USE_V1")
-
-        # Reject if V1.
-        m.setenv("VLLM_USE_V1", "1")
-        with pytest.raises(NotImplementedError):
-            AsyncEngineArgs(model=MODEL).create_engine_config()
-        m.delenv("VLLM_USE_V1")
-
-        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHMLA")
-        _ = AsyncEngineArgs(model=MODEL).create_engine_config()
-        assert envs.VLLM_USE_V1
-        m.delenv("VLLM_USE_V1")
@@ -364,7 +364,7 @@ class Attention(nn.Module, AttentionLayerBase):
             self.impl.process_weights_after_loading(act_dtype)

             # FlashInfer requires attention sinks to be float32
-            if (self.backend == _Backend.FLASHINFER_VLLM_V1
+            if (self.backend == _Backend.FLASHINFER
                     and hasattr(self.impl, 'sinks')):
                 from vllm.v1.attention.backends.flashinfer import FlashInferImpl
                 assert isinstance(self.impl, FlashInferImpl)
@@ -420,21 +420,17 @@ class MultiHeadAttention(nn.Module):

         self.attn_backend = backend if backend in {
             _Backend.TORCH_SDPA,
-            _Backend.TORCH_SDPA_VLLM_V1,
             _Backend.XFORMERS,
-            _Backend.PALLAS_VLLM_V1,
+            _Backend.PALLAS,
             _Backend.ROCM_AITER_FA,
             _Backend.FLASH_ATTN,
-            _Backend.FLASH_ATTN_VLLM_V1,
         } else _Backend.TORCH_SDPA

         if (self.attn_backend == _Backend.XFORMERS
                 and not check_xformers_availability()):
             self.attn_backend = _Backend.TORCH_SDPA

-        if self.attn_backend in {
-                _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1
-        }:
+        if self.attn_backend == _Backend.FLASH_ATTN:
             if use_upstream_fa:
                 from flash_attn import flash_attn_varlen_func
                 self._flash_attn_varlen_func = flash_attn_varlen_func
@@ -468,11 +464,7 @@ class MultiHeadAttention(nn.Module):
             key = torch.repeat_interleave(key, num_repeat, dim=2)
             value = torch.repeat_interleave(value, num_repeat, dim=2)

-        if self.attn_backend in {
-                _Backend.FLASH_ATTN,
-                _Backend.FLASH_ATTN_VLLM_V1,
-        }:
+        if self.attn_backend == _Backend.FLASH_ATTN:

             cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                                         step=q_len,
                                         dtype=torch.int32,
@@ -499,8 +491,7 @@ class MultiHeadAttention(nn.Module):
                 key,
                 value,
                 scale=self.scale)
-        elif (self.attn_backend == _Backend.TORCH_SDPA
-              or self.attn_backend == _Backend.TORCH_SDPA_VLLM_V1):
+        elif self.attn_backend == _Backend.TORCH_SDPA:
             query, key, value = (x.transpose(1, 2)
                                  for x in (query, key, value))
             out = F.scaled_dot_product_attention(query,
@@ -508,7 +499,7 @@ class MultiHeadAttention(nn.Module):
                                                  value,
                                                  scale=self.scale)
             out = out.transpose(1, 2)
-        elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
+        elif self.attn_backend == _Backend.PALLAS:
             query, key, value = (x.transpose(1, 2)
                                  for x in (query, key, value))
             from torch_xla.experimental.custom_kernel import flash_attention
@@ -186,6 +186,14 @@ def _cached_get_attn_backend(
     # Check the environment variable and override if specified
     backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
     if backend_by_env_var is not None:
+        if backend_by_env_var.endswith("_VLLM_V1"):
+            logger.warning(
+                "The suffix '_VLLM_V1' in the environment variable "
+                "%s is no longer necessary as V0 backends have been "
+                "deprecated. Please remove this suffix from your "
+                "environment variable setting.", STR_BACKEND_ENV_VAR)
+            backend_by_env_var = backend_by_env_var.removesuffix(
+                "_VLLM_V1")
         selected_backend = backend_name_to_enum(backend_by_env_var)
         if selected_backend is None:
             raise ValueError(
@@ -577,8 +577,8 @@ class NixlConnectorWorker:
             use_mla=self.use_mla)
         self.backend_name = backend.get_name()
         attn_backend = backend_name_to_enum(self.backend_name)
-        self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
-        self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
+        self._use_flashinfer = attn_backend == _Backend.FLASHINFER
+        self._use_pallas = attn_backend == _Backend.PALLAS
         self.kv_cache_layout = get_kv_cache_layout()
         logger.debug("Detected attention backend %s", self.backend_name)
         logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
@@ -749,7 +749,7 @@ class NixlConnectorWorker:
         # (roughly 8KB vs 5KB).
         # Conversely for FlashInfer, K and V are registered in the same region
         # to better exploit the memory layout (ie num_blocks is the first dim).
-        split_k_and_v = not (self.use_mla or self._use_pallas_v1
+        split_k_and_v = not (self.use_mla or self._use_pallas
                              or self._use_flashinfer)
         tensor_size_bytes = None
         for layer_name, cache_or_caches in xfer_buffers.items():
@@ -938,7 +938,7 @@ class NixlConnectorWorker:
             tp_ratio = divide(self._tp_size[self.engine_id],
                               self._tp_size[engine_id])
             assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
-            assert not self._use_pallas_v1 or tp_ratio == 1, \
+            assert not self._use_pallas or tp_ratio == 1, \
                 "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."

             # Handle tp_size>num_kv_heads: replicate KV cache.
@@ -1479,25 +1479,21 @@ class EngineArgs:
                 "such as ngram, medusa, eagle, or deepseek_mtp.")

         V1_BACKENDS = [
-            "FLASH_ATTN_VLLM_V1",
             "FLASH_ATTN",
             "PALLAS",
-            "PALLAS_VLLM_V1",
-            "TRITON_ATTN_VLLM_V1",
+            "TRITON_ATTN",
             "TRITON_MLA",
             "CUTLASS_MLA",
             "FLASHMLA",
-            "FLASHMLA_VLLM_V1",
             "FLASH_ATTN_MLA",
             "FLASHINFER",
-            "FLASHINFER_VLLM_V1",
             "FLASHINFER_MLA",
             "ROCM_AITER_MLA",
-            "TORCH_SDPA_VLLM_V1",
+            "TORCH_SDPA",
             "FLEX_ATTENTION",
             "TREE_ATTN",
-            "XFORMERS_VLLM_V1",
-            "ROCM_ATTN_VLLM_V1",
+            "XFORMERS",
+            "ROCM_ATTN",
         ]
         if (envs.is_set("VLLM_ATTENTION_BACKEND")
                 and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
@@ -42,7 +42,7 @@ def kernel_warmup(worker: "Worker"):
     # and is not a pooling model
     def _is_flashinfer_backend(backend):
         try:
-            return backend.get_name() == "FLASHINFER_VLLM_V1"
+            return backend.get_name() == "FLASHINFER"
         except NotImplementedError:
             return False

@@ -241,9 +241,8 @@ class CudaPlatformBase(Platform):
         use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
             selected_backend is None and cls.is_device_capability(100)
             and block_size in [32, 64])
-        use_flashmla = selected_backend in [
-            _Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1
-        ] or (selected_backend is None and is_flashmla_supported()[0])
+        use_flashmla = selected_backend == _Backend.FLASHMLA or (
+            selected_backend is None and is_flashmla_supported()[0])
         use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
             selected_backend is None and flash_attn_supports_mla())
         use_triton = selected_backend == _Backend.TRITON_MLA or (
@@ -282,7 +281,7 @@ class CudaPlatformBase(Platform):
         if use_v1:
             FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
             FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
-            TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
+            TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
             FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
             TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
             XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501
@@ -300,16 +299,16 @@ class CudaPlatformBase(Platform):
             elif selected_backend == _Backend.FLEX_ATTENTION:
                 logger.info_once("Using FlexAttention backend on V1 engine.")
                 return FLEX_ATTENTION_V1
-            elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+            elif selected_backend == _Backend.TRITON_ATTN:
                 logger.info_once("Using Triton backend on V1 engine.")
-                return TRITON_ATTN_VLLM_V1
+                return TRITON_ATTN
             elif selected_backend == _Backend.FLASH_ATTN:
                 logger.info_once("Using Flash Attention backend on V1 engine.")
                 return FLASH_ATTN_V1
             elif selected_backend == _Backend.TREE_ATTN:
                 logger.info_once("Using Tree Attention backend on V1 engine.")
                 return TREE_ATTN_V1
-            elif selected_backend == _Backend.XFORMERS_VLLM_V1:
+            elif selected_backend == _Backend.XFORMERS:
                 logger.info_once("Using XFormers backend on V1 engine.")
                 return XFORMERS_V1

@@ -341,7 +340,7 @@ class CudaPlatformBase(Platform):
             if (has_sink or
                     use_fp8_kv_cache) and not cls.is_device_capability(90):
                 logger.info_once("Using Triton backend on V1 engine.")
-                return TRITON_ATTN_VLLM_V1
+                return TRITON_ATTN
             elif is_default_backend_supported := is_attn_backend_supported(
                     FLASH_ATTN_V1, head_size, dtype,
                     allow_import_error=False):
@@ -457,12 +456,12 @@ class CudaPlatformBase(Platform):
         else:
             # Default to FlashAttention
             if attention_backend is None:
-                attention_backend = "FLASH_ATTN_VLLM_V1"
+                attention_backend = "FLASH_ATTN"

             # All Blackwell backends support fp8
             if cls.is_device_capability(100):
                 supported = True
-            elif attention_backend == "FLASH_ATTN_VLLM_V1":
+            elif attention_backend == "FLASH_ATTN":
                 if fp8_attention:
                     from vllm.attention.utils.fa_utils import (
                         flash_attn_supports_fp8)
@@ -471,7 +470,7 @@ class CudaPlatformBase(Platform):
                     supported = True
             elif attention_backend == "FLASHINFER":
                 supported = True
-            elif attention_backend == "TRITON_ATTN_VLLM_V1":
+            elif attention_backend == "TRITON_ATTN":
                 supported = cls.supports_fp8()
         return supported

@@ -40,34 +40,26 @@ def in_wsl() -> bool:

 class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()
-    FLASH_ATTN_VLLM_V1 = enum.auto()
-    TRITON_ATTN_VLLM_V1 = enum.auto()
+    TRITON_ATTN = enum.auto()
     XFORMERS = enum.auto()
     ROCM_FLASH = enum.auto()
     ROCM_AITER_MLA = enum.auto()  # Supported by V1
-    ROCM_AITER_MLA_VLLM_V1 = enum.auto()
     ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
     TORCH_SDPA = enum.auto()
-    TORCH_SDPA_VLLM_V1 = enum.auto()
     FLASHINFER = enum.auto()
-    FLASHINFER_VLLM_V1 = enum.auto()
     FLASHINFER_MLA = enum.auto()
     TRITON_MLA = enum.auto()  # Supported by V1
-    TRITON_MLA_VLLM_V1 = enum.auto()
     CUTLASS_MLA = enum.auto()
     FLASHMLA = enum.auto()  # Supported by V1
-    FLASHMLA_VLLM_V1 = enum.auto()
     FLASH_ATTN_MLA = enum.auto()  # Supported by V1
     PALLAS = enum.auto()
-    PALLAS_VLLM_V1 = enum.auto()
     IPEX = enum.auto()
     DUAL_CHUNK_FLASH_ATTN = enum.auto()
     DIFFERENTIAL_FLASH_ATTN = enum.auto()
     NO_ATTENTION = enum.auto()
     FLEX_ATTENTION = enum.auto()
     TREE_ATTN = enum.auto()
-    XFORMERS_VLLM_V1 = enum.auto()
-    ROCM_ATTN_VLLM_V1 = enum.auto()
+    ROCM_ATTN = enum.auto()


 class PlatformEnum(enum.Enum):
@@ -218,8 +218,7 @@ class RocmPlatform(Platform):
                 raise ValueError(
                     f" The selected backend, {selected_backend.name},"
                     f"does not support block size {block_size}.")
-        if selected_backend in (_Backend.ROCM_AITER_MLA,
-                                _Backend.ROCM_AITER_MLA_VLLM_V1):
+        if selected_backend == _Backend.ROCM_AITER_MLA:
             if block_size == 1:
                 logger.info("Using AITER MLA backend on V1 engine.")
                 return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
@@ -240,7 +239,7 @@ class RocmPlatform(Platform):
         elif (envs.VLLM_ROCM_USE_AITER and
               envs.VLLM_USE_AITER_UNIFIED_ATTENTION) or \
                 envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or \
-                selected_backend == _Backend.ROCM_ATTN_VLLM_V1:
+                selected_backend == _Backend.ROCM_ATTN:
             # rocm specific backend, with aiter and/or
             # triton prefix-prefill
             logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
@@ -50,8 +50,7 @@ class TpuPlatform(Platform):
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
                              has_sink) -> str:
-        if (selected_backend != _Backend.PALLAS
-                and selected_backend != _Backend.PALLAS_VLLM_V1):
+        if selected_backend != _Backend.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)

         if not use_v1:
@@ -40,14 +40,14 @@ class XPUPlatform(Platform):
         use_v1 = envs.VLLM_USE_V1
         if not use_v1:
             raise ValueError("XPU backend only supports V1.")
-        TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
-        FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
-        if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+        TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
+        FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        if selected_backend == _Backend.TRITON_ATTN:
             logger.info_once("Using Triton backend on V1 engine.")
-            return TRITON_ATTN_VLLM_V1
+            return TRITON_ATTN
         elif selected_backend == _Backend.FLASH_ATTN:
             logger.info_once("Using Flash Attention backend on V1 engine.")
-            return FLASH_ATTN_V1
+            return FLASH_ATTN
         elif selected_backend:
             raise ValueError(
                 f"Invalid attention backend for {cls.device_name}, "
@@ -64,7 +64,7 @@ class XPUPlatform(Platform):
         XPU only support fp8 kv cache with triton backend.
         """
         if envs.is_set("VLLM_ATTENTION_BACKEND") and \
-            envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN_VLLM_V1":
+            envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN":
             return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"]

         return False
@@ -54,7 +54,7 @@ class TorchSDPABackend(AttentionBackend):

     @staticmethod
     def get_name() -> str:
-        return "TORCH_SDPA_VLLM_V1"
+        return "TORCH_SDPA"

     @staticmethod
     def get_impl_cls() -> type["TorchSDPABackendImpl"]:
@@ -60,7 +60,7 @@ class FlashAttentionBackend(AttentionBackend):

     @staticmethod
     def get_name() -> str:
-        return "FLASH_ATTN_VLLM_V1"
+        return "FLASH_ATTN"

     @staticmethod
     def get_impl_cls() -> type["FlashAttentionImpl"]:
@@ -167,7 +167,7 @@ class FlashInferBackend(AttentionBackend):

     @staticmethod
     def get_name() -> str:
-        return "FLASHINFER_VLLM_V1"
+        return "FLASHINFER"

     @staticmethod
     def get_impl_cls() -> type[FlashInferImpl]:
@@ -270,7 +270,7 @@ class MLACommonBackend(AttentionBackend):

     @staticmethod
     def get_name() -> str:
-        return "TRITON_MLA_VLLM_V1"
+        return "TRITON_MLA"

     @staticmethod
     def get_metadata_cls() -> type["AttentionMetadata"]:
@@ -27,7 +27,7 @@ class FlashMLABackend(MLACommonBackend):

     @staticmethod
     def get_name() -> str:
-        return "FLASHMLA_VLLM_V1"
+        return "FLASHMLA"

     @staticmethod
     def get_metadata_cls() -> type["FlashMLAMetadata"]:
@@ -33,7 +33,7 @@ class AiterMLABackend(MLACommonBackend):

     @staticmethod
     def get_name() -> str:
-        return "ROCM_AITER_MLA_VLLM_V1"
+        return "ROCM_AITER_MLA"

     @staticmethod
     def get_impl_cls() -> type["AiterMLAImpl"]:
@@ -24,7 +24,7 @@ class TritonMLABackend(MLACommonBackend):

     @staticmethod
     def get_name() -> str:
-        return "TRITON_MLA_VLLM_V1"
+        return "TRITON_MLA"

     @staticmethod
     def get_impl_cls() -> type["TritonMLAImpl"]:
@@ -86,7 +86,7 @@ class PallasAttentionBackend(AttentionBackend):

     @staticmethod
     def get_name() -> str:
-        return "PALLAS_VLLM_V1"
+        return "PALLAS"

     @staticmethod
     def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
@@ -340,7 +340,7 @@ class AiterFlashAttentionBackend(AttentionBackend):

     @staticmethod
     def get_name() -> str:
-        return "FLASH_ATTN_VLLM_V1"
+        return "FLASH_ATTN"

     @staticmethod
     def get_impl_cls() -> type["AiterFlashAttentionImpl"]:
@@ -159,7 +159,7 @@ class RocmAttentionBackend(AttentionBackend):

     @staticmethod
     def get_name() -> str:
-        return "ROCM_ATTN_VLLM_V1"
+        return "ROCM_ATTN"

     @staticmethod
     def get_impl_cls() -> type["RocmAttentionImpl"]:
@@ -52,7 +52,7 @@ class TreeAttentionBackend(AttentionBackend):

     @staticmethod
     def get_name() -> str:
-        return "TREE_ATTN_VLLM_V1"
+        return "TREE_ATTN"

     @staticmethod
     def get_impl_cls() -> type["TreeAttentionImpl"]:
@@ -155,7 +155,7 @@ class TritonAttentionBackend(AttentionBackend):

     @staticmethod
     def get_name() -> str:
-        return "TRITON_ATTN_VLLM_V1"
+        return "TRITON_ATTN"

     @staticmethod
     def get_impl_cls() -> type["TritonAttentionImpl"]:
@@ -90,7 +90,7 @@ class XFormersAttentionBackend(AttentionBackend):

     @staticmethod
     def get_name() -> str:
-        return "XFORMERS_VLLM_V1"
+        return "XFORMERS"

     @staticmethod
     def get_impl_cls() -> type["XFormersAttentionImpl"]:
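For users upgrading existing configurations: the selector hunk above (@@ -186,6 +186,14 in _cached_get_attn_backend) keeps old suffixed values of VLLM_ATTENTION_BACKEND working by warning and stripping the suffix before the name is resolved. A minimal, self-contained sketch of that behaviour, assuming nothing beyond the standard library (the helper name normalize_backend_name is ours, not vLLM's, and warnings.warn stands in for vLLM's logger; in vLLM the logic runs inline on envs.VLLM_ATTENTION_BACKEND):

# Simplified sketch of the backward-compatibility shim added in this commit.
import warnings


def normalize_backend_name(name: str) -> str:
    # Old V0-style names carried a "_VLLM_V1" suffix, e.g. "TRITON_ATTN_VLLM_V1".
    if name.endswith("_VLLM_V1"):
        warnings.warn(
            "The suffix '_VLLM_V1' is no longer necessary as V0 backends "
            "have been deprecated; please remove it from "
            "VLLM_ATTENTION_BACKEND.",
            DeprecationWarning,
            stacklevel=2,
        )
        name = name.removesuffix("_VLLM_V1")
    return name


assert normalize_backend_name("TRITON_ATTN_VLLM_V1") == "TRITON_ATTN"
assert normalize_backend_name("FLASHINFER") == "FLASHINFER"  # already suffix-free

So setting VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 still selects the Triton backend for now, but it triggers a deprecation warning; the suffix-free names used throughout this commit are the ones to migrate to.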